diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index a6b2b30a3..6a8f991e4 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -84,6 +84,9 @@ SECRET_KEY=SECRET # JWT Token Lifetimes (optional, defaults shown) # ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day # REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks +# Personal Access Tokens (PATs). Empty/unset = no maximum; users may create +# never-expiring PATs. When set, PAT creation requires an expiry <= this many days. +# PAT_MAX_EXPIRY_DAYS= NEXT_FRONTEND_URL=http://localhost:3000 diff --git a/surfsense_backend/alembic/versions/166_add_pat_and_api_access.py b/surfsense_backend/alembic/versions/166_add_pat_and_api_access.py new file mode 100644 index 000000000..b49b099a6 --- /dev/null +++ b/surfsense_backend/alembic/versions/166_add_pat_and_api_access.py @@ -0,0 +1,83 @@ +"""Add personal access tokens and search-space API access gate. + +Revision ID: 166 +Revises: 165 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "166" +down_revision: str | None = "165" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + """ + CREATE TABLE IF NOT EXISTS personal_access_tokens ( + id SERIAL PRIMARY KEY, + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + token_hash VARCHAR(64) NOT NULL, + token_prefix VARCHAR(16) NOT NULL, + label VARCHAR NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE, + last_used_at TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL + ); + """ + ) + + op.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ix_personal_access_tokens_token_hash " + "ON personal_access_tokens (token_hash)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_user_id " + "ON personal_access_tokens (user_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_id " + "ON personal_access_tokens (id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_created_at " + "ON personal_access_tokens (created_at)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_expires_at " + "ON personal_access_tokens (expires_at)" + ) + + bind = op.get_bind() + api_access_column_exists = bind.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = 'searchspaces' + AND column_name = 'api_access_enabled' + ) + """ + ) + ).scalar() + + op.execute( + "ALTER TABLE searchspaces ADD COLUMN IF NOT EXISTS " + "api_access_enabled BOOLEAN NOT NULL DEFAULT false" + ) + + if not api_access_column_exists: + op.execute("UPDATE searchspaces SET api_access_enabled = true") + + +def downgrade() -> None: + op.execute( + "ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled" + ) + op.execute("DROP TABLE IF EXISTS personal_access_tokens") diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index 10a734192..d823a5a06 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -34,6 +34,7 @@ from app.agents.chat.runtime.llm_config import AgentConfig from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, ) +from app.auth.context import AuthContext from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.services.user_tool_allowlist import ( @@ -73,6 +74,7 @@ async def create_multi_agent_chat_deep_agent( anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, image_gen_model_id: int | None = None, + auth_context: AuthContext | None = None, ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. @@ -139,6 +141,7 @@ async def create_multi_agent_chat_deep_agent( "connector_service": connector_service, "firecrawl_api_key": firecrawl_api_key, "user_id": user_id, + "auth_context": auth_context, "thread_id": thread_id, "thread_visibility": visibility, "available_connectors": available_connectors, diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py index 4472a11ac..fe42410ed 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py @@ -30,9 +30,10 @@ from pydantic import ValidationError from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( request_approval, ) +from app.auth.context import AuthContext from app.automations.schemas.api import AutomationCreate from app.automations.services.automation import AutomationService -from app.db import User, async_session_maker +from app.db import async_session_maker from app.utils.content_utils import extract_text_content from .prompt import build_draft_prompt @@ -47,6 +48,7 @@ def create_create_automation_tool( search_space_id: int, user_id: str | UUID, llm: Any, + auth_context: AuthContext | None = None, ): """Factory for the ``create_automation`` tool. @@ -56,8 +58,6 @@ def create_create_automation_tool( ``AsyncSession`` is opened per call to avoid stale sessions on compiled-agent cache hits (same pattern as the Notion / memory tools). """ - uid = UUID(user_id) if isinstance(user_id, str) else user_id - @tool async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]: """Draft + save an automation from a natural-language intent. @@ -165,14 +165,17 @@ def create_create_automation_tool( "issues": _format_validation_issues(exc), } + if auth_context is None: + logger.error( + "create_automation called without AuthContext; refusing to persist" + ) + return { + "status": "error", + "message": "authorization context missing for automation creation", + } + async with async_session_maker() as session: - user = await session.get(User, uid) - if user is None: - return { - "status": "error", - "message": "user not found in this session", - } - service = AutomationService(session=session, user=user) + service = AutomationService(session=session, auth=auth_context) created = await service.create(final_validated) return { "status": "saved", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py index f04d7cdec..5e7c2d5d6 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/registry.py @@ -60,6 +60,7 @@ def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool: return create_create_automation_tool( search_space_id=deps["search_space_id"], user_id=deps["user_id"], + auth_context=deps.get("auth_context"), llm=deps["llm"], ) diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 6dfe6a776..e6aa2fa3e 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -27,6 +27,7 @@ from app.agents.chat.runtime.checkpointer import ( close_checkpointer, setup_checkpointer_tables, ) +from app.auth.context import AuthContext from app.config import ( config, initialize_image_gen_router, @@ -34,7 +35,7 @@ from app.config import ( initialize_openrouter_integration, initialize_pricing_registration, ) -from app.db import User, create_db_and_tables, get_async_session +from app.db import create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError from app.gateway.byo_long_poll import ( start_byo_long_poll_supervisors, @@ -55,7 +56,7 @@ from app.routes import router as crud_router from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate from app.session_events import register_session_hooks -from app.users import SECRET, auth_backend, current_active_user, fastapi_users +from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users from app.utils.perf import log_system_snapshot _error_logger = logging.getLogger("surfsense.errors") @@ -1032,7 +1033,7 @@ async def readiness_check(): @app.get("/verify-token") async def authenticated_route( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(allow_any_principal), session: AsyncSession = Depends(get_async_session), ): - return {"message": "Token is valid"} + return {"message": "Token is valid", "method": auth.method} diff --git a/surfsense_backend/app/auth/__init__.py b/surfsense_backend/app/auth/__init__.py new file mode 100644 index 000000000..0486f3d79 --- /dev/null +++ b/surfsense_backend/app/auth/__init__.py @@ -0,0 +1 @@ +"""Authentication principals and helpers.""" diff --git a/surfsense_backend/app/auth/context.py b/surfsense_backend/app/auth/context.py new file mode 100644 index 000000000..d14c9f784 --- /dev/null +++ b/surfsense_backend/app/auth/context.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from app.db import PersonalAccessToken, User + +AuthMethod = Literal["session", "pat", "system"] + + +@dataclass(frozen=True) +class AuthContext: + """Typed principal for authorization decisions.""" + + user: User + method: AuthMethod + pat: PersonalAccessToken | None = None + source: str | None = None + + @classmethod + def session(cls, user: User) -> AuthContext: + return cls(user=user, method="session") + + @classmethod + def pat_auth(cls, user: User, pat: PersonalAccessToken) -> AuthContext: + return cls(user=user, method="pat", pat=pat) + + @classmethod + def system(cls, user: User, source: str) -> AuthContext: + return cls(user=user, method="system", source=source) + + @property + def is_gated(self) -> bool: + return self.method == "pat" + + @property + def is_session(self) -> bool: + return self.method == "session" diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index c3a35930d..b2f441961 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -16,7 +16,8 @@ from app.agents.chat.runtime.mention_resolver import ( substitute_in_text, ) from app.agents.chat.shared.context import SurfSenseContextSchema -from app.db import ChatVisibility, async_session_maker +from app.auth.context import AuthContext +from app.db import ChatVisibility, User, async_session_maker from app.schemas.new_chat import MentionedDocumentInfo from ...types import ActionContext @@ -147,6 +148,12 @@ async def run_agent_task( decision = "approve" if auto_approve_all else "reject" async with async_session_maker() as agent_session: + auth_context = None + if ctx.creator_user_id: + user = await agent_session.get(User, ctx.creator_user_id) + if user is not None: + auth_context = AuthContext.system(user, source="automation") + deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, @@ -168,6 +175,7 @@ async def run_agent_task( thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, image_gen_model_id=ctx.image_gen_model_id, + auth_context=auth_context, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 1d371c35d..261d41bfc 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -27,17 +27,19 @@ from app.automations.services.model_policy import ( ) from app.automations.triggers import get_trigger from app.automations.triggers.builtin.schedule import compute_next_fire_at -from app.db import Permission, SearchSpace, User, get_async_session -from app.users import current_active_user +from app.auth.context import AuthContext +from app.db import Permission, SearchSpace, get_async_session +from app.users import get_auth_context from app.utils.rbac import check_permission class AutomationService: """Lifecycle of the ``Automation`` resource.""" - def __init__(self, *, session: AsyncSession, user: User) -> None: + def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None: self.session = session - self.user = user + self.auth = auth + self.user = auth.user async def create(self, payload: AutomationCreate) -> Automation: """Create an automation and its initial triggers in one transaction.""" @@ -235,7 +237,7 @@ class AutomationService: async def _authorize(self, search_space_id: int, permission: str) -> None: await check_permission( self.session, - self.user, + self.auth, search_space_id, permission, f"You don't have permission to {permission.split(':')[1]} automations in this search space", @@ -274,6 +276,6 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger: def get_automation_service( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> AutomationService: - return AutomationService(session=session, user=user) + return AutomationService(session=session, auth=auth) diff --git a/surfsense_backend/app/automations/services/run.py b/surfsense_backend/app/automations/services/run.py index 3ef80416f..8ef763e5e 100644 --- a/surfsense_backend/app/automations/services/run.py +++ b/surfsense_backend/app/automations/services/run.py @@ -8,17 +8,18 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.automations.persistence.models.automation import Automation from app.automations.persistence.models.run import AutomationRun -from app.db import Permission, User, get_async_session -from app.users import current_active_user +from app.auth.context import AuthContext +from app.db import Permission, get_async_session +from app.users import get_auth_context from app.utils.rbac import check_permission class RunService: """Read-only access to ``AutomationRun`` history.""" - def __init__(self, *, session: AsyncSession, user: User) -> None: + def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None: self.session = session - self.user = user + self.auth = auth async def list( self, @@ -63,7 +64,7 @@ class RunService: ) await check_permission( self.session, - self.user, + self.auth, automation.search_space_id, permission, f"You don't have permission to {permission.split(':')[1]} automations in this search space", @@ -73,6 +74,6 @@ class RunService: def get_run_service( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> RunService: - return RunService(session=session, user=user) + return RunService(session=session, auth=auth) diff --git a/surfsense_backend/app/automations/services/trigger.py b/surfsense_backend/app/automations/services/trigger.py index 523153927..7ff6e56fa 100644 --- a/surfsense_backend/app/automations/services/trigger.py +++ b/surfsense_backend/app/automations/services/trigger.py @@ -14,17 +14,18 @@ from app.automations.persistence.models.trigger import AutomationTrigger from app.automations.schemas.api import TriggerCreate, TriggerUpdate from app.automations.triggers import get_trigger from app.automations.triggers.builtin.schedule import compute_next_fire_at -from app.db import Permission, User, get_async_session -from app.users import current_active_user +from app.auth.context import AuthContext +from app.db import Permission, get_async_session +from app.users import get_auth_context from app.utils.rbac import check_permission class TriggerService: """Lifecycle of the ``AutomationTrigger`` sub-resource.""" - def __init__(self, *, session: AsyncSession, user: User) -> None: + def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None: self.session = session - self.user = user + self.auth = auth async def add( self, *, automation_id: int, payload: TriggerCreate @@ -101,7 +102,7 @@ class TriggerService: ) await check_permission( self.session, - self.user, + self.auth, automation.search_space_id, permission, f"You don't have permission to {permission.split(':')[1]} automations in this search space", @@ -144,6 +145,6 @@ def _initial_next_fire( def get_trigger_service( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> TriggerService: - return TriggerService(session=session, user=user) + return TriggerService(session=session, auth=auth) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 63be54654..b998f05cf 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -919,6 +919,10 @@ class Config: REFRESH_TOKEN_LIFETIME_SECONDS = int( os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks ) + _PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip() + PAT_MAX_EXPIRY_DAYS = ( + int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None + ) # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 3f098d5d2..a65a964fd 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -368,6 +368,9 @@ class Permission(StrEnum): SETTINGS_UPDATE = "settings:update" SETTINGS_DELETE = "settings:delete" # Delete the entire search space + # API Access + API_ACCESS_MANAGE = "api_access:manage" + # Public Sharing PUBLIC_SHARING_VIEW = "public_sharing:view" PUBLIC_SHARING_CREATE = "public_sharing:create" @@ -1693,6 +1696,9 @@ class SearchSpace(BaseModel, TimestampMixin): citations_enabled = Column( Boolean, nullable=False, default=True ) # Enable/disable citations + api_access_enabled = Column( + Boolean, nullable=False, default=False, server_default="false" + ) qna_custom_instructions = Column( Text, nullable=True, default="" ) # User's custom instructions @@ -2330,6 +2336,11 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", cascade="all, delete-orphan", ) + personal_access_tokens = relationship( + "PersonalAccessToken", + back_populates="user", + cascade="all, delete-orphan", + ) else: @@ -2462,6 +2473,11 @@ else: back_populates="user", cascade="all, delete-orphan", ) + personal_access_tokens = relationship( + "PersonalAccessToken", + back_populates="user", + cascade="all, delete-orphan", + ) class AgentActionLog(BaseModel): @@ -2712,6 +2728,36 @@ class RefreshToken(Base, TimestampMixin): return not self.is_expired and not self.is_revoked +class PersonalAccessToken(BaseModel, TimestampMixin): + """ + Stores hashed Personal Access Tokens for programmatic API access. + Plaintext tokens are shown once on creation and are never persisted. + """ + + __tablename__ = "personal_access_tokens" + + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user = relationship("User", back_populates="personal_access_tokens") + token_hash = Column(String(64), unique=True, nullable=False, index=True) + token_prefix = Column(String(16), nullable=False) + label = Column(String, nullable=False) + expires_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) + last_used_at = Column(TIMESTAMP(timezone=True), nullable=True) + + @property + def is_expired(self) -> bool: + return self.expires_at is not None and datetime.now(UTC) >= self.expires_at + + @property + def is_valid(self) -> bool: + return not self.is_expired + + # Register model packages that live outside this file so their classes # are present in Base.metadata before configure_mappers() resolves any # string-based relationship() references. diff --git a/surfsense_backend/app/file_storage/api.py b/surfsense_backend/app/file_storage/api.py index c649ba63d..80417baaf 100644 --- a/surfsense_backend/app/file_storage/api.py +++ b/surfsense_backend/app/file_storage/api.py @@ -9,7 +9,8 @@ from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Document, Permission, User, get_async_session +from app.auth.context import AuthContext +from app.db import Document, Permission, get_async_session from app.file_storage.persistence.enums import DocumentFileKind from app.file_storage.schemas import DocumentFileRead from app.file_storage.service import ( @@ -17,14 +18,14 @@ from app.file_storage.service import ( list_document_files, open_document_file_stream, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission router = APIRouter() async def _load_readable_document( - *, document_id: int, session: AsyncSession, user: User + *, document_id: int, session: AsyncSession, auth: AuthContext ) -> Document: """Load a document the user may read, or raise 404/403.""" document = ( @@ -35,7 +36,7 @@ async def _load_readable_document( await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -57,10 +58,10 @@ def _content_disposition(filename: str) -> str: async def read_document_files( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> list[DocumentFileRead]: """Return metadata for every stored file of a document (gates the UI).""" - await _load_readable_document(document_id=document_id, session=session, user=user) + await _load_readable_document(document_id=document_id, session=session, auth=auth) records = await list_document_files(session, document_id=document_id) return [DocumentFileRead.model_validate(r) for r in records] @@ -69,10 +70,10 @@ async def read_document_files( async def download_original_document_file( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> StreamingResponse: """Stream the document's original uploaded file.""" - await _load_readable_document(document_id=document_id, session=session, user=user) + await _load_readable_document(document_id=document_id, session=session, auth=auth) record = await get_document_file( session, document_id=document_id, kind=DocumentFileKind.ORIGINAL diff --git a/surfsense_backend/app/gateway/agent_invoke.py b/surfsense_backend/app/gateway/agent_invoke.py index 8701ccc55..e03ea8c8b 100644 --- a/surfsense_backend/app/gateway/agent_invoke.py +++ b/surfsense_backend/app/gateway/agent_invoke.py @@ -9,6 +9,7 @@ from collections.abc import AsyncIterator from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ExternalChatBinding, NewChatMessage from app.gateway.auth_invariant import assert_authorization_invariant from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent @@ -64,6 +65,7 @@ async def call_agent_for_gateway( request_id: str | None = None, ) -> None: user = await assert_authorization_invariant(session, binding) + auth_context = AuthContext.system(user, source="gateway") thread = await get_or_create_thread_for_binding(session, binding) await session.commit() @@ -81,6 +83,7 @@ async def call_agent_for_gateway( current_user_display_name=user.display_name or "A team member", disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES), request_id=request_id or "gateway", + auth_context=auth_context, ) events = _events_from_sse(stream) try: diff --git a/surfsense_backend/app/gateway/auth_invariant.py b/surfsense_backend/app/gateway/auth_invariant.py index e72023ce1..008250957 100644 --- a/surfsense_backend/app/gateway/auth_invariant.py +++ b/surfsense_backend/app/gateway/auth_invariant.py @@ -5,6 +5,7 @@ from __future__ import annotations from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ExternalChatBinding, Permission, User from app.gateway.bindings import suspend_binding from app.observability.metrics import record_gateway_auth_invariant_failure @@ -39,11 +40,13 @@ async def assert_authorization_invariant( if user is None: await _fail(session, binding, "owner_missing") + auth = AuthContext.system(user, source="gateway") + try: - await check_search_space_access(session, user, binding.search_space_id) + await check_search_space_access(session, auth, binding.search_space_id) await check_permission( session, - user, + auth, binding.search_space_id, Permission.CHATS_CREATE.value, "External chat owner no longer has permission to chat in this search space", diff --git a/surfsense_backend/app/notifications/api/api.py b/surfsense_backend/app/notifications/api/api.py index 9a136ca7b..7794a5867 100644 --- a/surfsense_backend/app/notifications/api/api.py +++ b/surfsense_backend/app/notifications/api/api.py @@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy import case, desc, func, literal, literal_column, select, update from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.notifications.api.schemas import ( BatchUnreadCountResponse, CategoryUnreadCount, @@ -27,7 +28,7 @@ from app.notifications.api.transform import ( from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS from app.notifications.persistence import Notification from app.notifications.types import NotificationCategory, NotificationType -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(prefix="/notifications", tags=["notifications"]) @@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"]) @router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse) async def get_unread_counts_batch( search_space_id: int | None = Query(None, description="Filter by search space ID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> BatchUnreadCountResponse: """Unread counts for every category in a single query.""" + user = auth.user cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) base_filter = [ @@ -86,10 +88,11 @@ async def get_unread_counts_batch( @router.get("/source-types", response_model=SourceTypesResponse) async def get_notification_source_types( search_space_id: int | None = Query(None, description="Filter by search space ID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> SourceTypesResponse: """Distinct connector/document source types for the Status tab filter.""" + user = auth.user base_filter = [Notification.user_id == user.id] if search_space_id is not None: @@ -160,7 +163,7 @@ async def get_unread_count( category: NotificationCategory | None = Query( None, description="Filter by category: 'comments' or 'status'" ), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> UnreadCountResponse: """Total and recent (within sync window) unread counts for the user. @@ -168,6 +171,7 @@ async def get_unread_count( Returning both lets a client hold the older count static while live-syncing the recent ones. """ + user = auth.user cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) base_filter = [ @@ -230,10 +234,11 @@ async def list_notifications( ), limit: int = Query(50, ge=1, le=100, description="Number of items to return"), offset: int = Query(0, ge=0, description="Number of items to skip"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> NotificationListResponse: """Paginated inbox fallback for items outside the Zero sync window.""" + user = auth.user query = select(Notification).where(Notification.user_id == user.id) count_query = select(func.count(Notification.id)).where( Notification.user_id == user.id @@ -328,10 +333,11 @@ async def list_notifications( @router.patch("/{notification_id}/read", response_model=MarkReadResponse) async def mark_notification_as_read( notification_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> MarkReadResponse: """Mark one of the user's notifications read; Zero syncs the change.""" + user = auth.user # Scope to the caller's own notifications. result = await session.execute( select(Notification).where( @@ -364,10 +370,11 @@ async def mark_notification_as_read( @router.patch("/read-all", response_model=MarkAllReadResponse) async def mark_all_notifications_as_read( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> MarkAllReadResponse: """Mark all of the user's notifications read; Zero syncs the changes.""" + user = auth.user result = await session.execute( update(Notification) .where( diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py index cfcb2ede9..582b0531e 100644 --- a/surfsense_backend/app/podcasts/api/routes.py +++ b/surfsense_backend/app/podcasts/api/routes.py @@ -18,12 +18,12 @@ from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config as app_config from app.db import ( Permission, SearchSpace, SearchSpaceMembership, - User, get_async_session, ) from app.podcasts.generation.brief import propose_brief @@ -42,7 +42,7 @@ from app.podcasts.voices import ( provider_from_service, render_voice_preview, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission from .schemas import ( @@ -63,13 +63,14 @@ async def list_podcasts( skip: int = 0, limit: int = 100, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user if skip < 0 or limit < 1: raise HTTPException(status_code=400, detail="Invalid pagination parameters") if search_space_id is not None: - await _require(session, user, search_space_id, Permission.PODCASTS_READ) + await _require(session, auth, search_space_id, Permission.PODCASTS_READ) query = ( select(Podcast) .where(Podcast.search_space_id == search_space_id) @@ -132,7 +133,7 @@ async def list_languages(): @router.get("/podcasts/voices/{voice_id}/preview") async def preview_voice( voice_id: str, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """A short audio sample of a voice, so users pick by sound.""" if not app_config.TTS_SERVICE: @@ -156,9 +157,9 @@ async def preview_voice( async def create_podcast( body: CreatePodcastRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE) + await _require(session, auth, body.search_space_id, Permission.PODCASTS_CREATE) service = PodcastService(session) podcast = await service.create( @@ -185,9 +186,9 @@ async def create_podcast( async def get_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ) return PodcastDetail.of(podcast) @@ -196,9 +197,9 @@ async def update_spec( podcast_id: int, body: UpdateSpecRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).update_spec( podcast, body.spec, body.expected_version @@ -211,10 +212,10 @@ async def update_spec( async def approve_brief( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """Approve the brief and start drafting the transcript.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).begin_drafting(podcast) await session.commit() @@ -228,10 +229,10 @@ async def approve_brief( async def regenerate_transcript( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """Reopen the brief gate for a fresh take; drafting waits for re-approval.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).regenerate(podcast) await session.commit() @@ -242,10 +243,10 @@ async def regenerate_transcript( async def revert_regeneration( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """Back out of a regeneration and return to the finished episode.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).revert_regeneration(podcast) await session.commit() @@ -256,9 +257,9 @@ async def revert_regeneration( async def cancel_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).cancel(podcast) await session.commit() @@ -269,9 +270,9 @@ async def cancel_podcast( async def delete_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_DELETE) await purge_audio(podcast) await session.delete(podcast) await session.commit() @@ -282,9 +283,9 @@ async def delete_podcast( async def stream_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ) if podcast.storage_key: # Verify first so a missing object is a 404, not a mid-stream crash. @@ -323,13 +324,13 @@ async def stream_podcast( async def _require( session: AsyncSession, - user: User, + auth: AuthContext, search_space_id: int, permission: Permission, ) -> None: await check_permission( session, - user, + auth, search_space_id, permission.value, "You don't have permission for podcasts in this search space", @@ -338,14 +339,14 @@ async def _require( async def _load( session: AsyncSession, - user: User, + auth: AuthContext, podcast_id: int, permission: Permission, ) -> Podcast: podcast = await PodcastRepository(session).get(podcast_id) if podcast is None: raise HTTPException(status_code=404, detail="Podcast not found") - await _require(session, user, podcast.search_space_id, permission) + await _require(session, auth, podcast.search_space_id, permission) return podcast diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 8ce84d179..caa1a2546 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -54,6 +54,7 @@ from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router from .onedrive_add_connector_route import router as onedrive_add_connector_router +from .personal_access_tokens_routes import router as personal_access_tokens_router from .prompts_routes import router as prompts_router from .public_chat_routes import router as public_chat_router from .rbac_routes import router as rbac_router @@ -113,6 +114,7 @@ router.include_router(slack_add_connector_router) router.include_router(teams_add_connector_router) router.include_router(onedrive_add_connector_router) router.include_router(obsidian_plugin_router) # Obsidian plugin push API +router.include_router(personal_access_tokens_router) # Personal access token manager router.include_router(discord_add_connector_router) router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py index 9a55fdec3..bf94ae3b4 100644 --- a/surfsense_backend/app/routes/agent_action_log_route.py +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -28,6 +28,7 @@ from pydantic import BaseModel from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags from app.db import ( AgentActionLog, @@ -36,7 +37,7 @@ from app.db import ( User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission logger = logging.getLogger(__name__) @@ -111,8 +112,9 @@ async def list_thread_actions( page: int = Query(0, ge=0), page_size: int = Query(50, ge=1, le=200), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> AgentActionListResponse: + user = auth.user """List agent actions for a thread, newest first. Authorization: @@ -132,7 +134,7 @@ async def list_thread_actions( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to view this thread's action log.", diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index e97608cbe..222909c59 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import ( AgentFeatureFlags, get_flags, ) +from app.auth.context import AuthContext from app.config import config -from app.db import User -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() @@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel): @router.get("/agent/flags", response_model=AgentFeatureFlagsRead) async def get_agent_flags( - _user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ) -> AgentFeatureFlagsRead: return AgentFeatureFlagsRead.from_flags(get_flags()) diff --git a/surfsense_backend/app/routes/agent_permissions_route.py b/surfsense_backend/app/routes/agent_permissions_route.py index 0c07eeb9c..521adfb03 100644 --- a/surfsense_backend/app/routes/agent_permissions_route.py +++ b/surfsense_backend/app/routes/agent_permissions_route.py @@ -30,6 +30,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags from app.db import ( AgentPermissionRule, @@ -39,7 +40,7 @@ from app.db import ( User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission logger = logging.getLogger(__name__) @@ -133,15 +134,16 @@ def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead: async def _ensure_search_space_membership_admin( - session: AsyncSession, user: User, search_space_id: int + session: AsyncSession, auth: AuthContext, search_space_id: int ) -> None: + user = auth.user """Curating agent rules == "settings" administration on the space.""" space = await session.get(SearchSpace, search_space_id) if space is None: raise HTTPException(status_code=404, detail="Search space not found.") await check_permission( session, - user, + auth, search_space_id, Permission.SETTINGS_UPDATE.value, "You don't have permission to manage agent permission rules in this space.", @@ -160,8 +162,9 @@ async def _ensure_search_space_membership_admin( async def list_rules( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> list[AgentPermissionRuleRead]: + user = auth.user _flag_guard() await _ensure_search_space_membership_admin(session, user, search_space_id) @@ -183,8 +186,9 @@ async def create_rule( search_space_id: int, payload: AgentPermissionRuleCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> AgentPermissionRuleRead: + user = auth.user _flag_guard() await _ensure_search_space_membership_admin(session, user, search_space_id) @@ -232,8 +236,9 @@ async def update_rule( rule_id: int, payload: AgentPermissionRuleUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> AgentPermissionRuleRead: + user = auth.user _flag_guard() await _ensure_search_space_membership_admin(session, user, search_space_id) @@ -266,8 +271,9 @@ async def delete_rule( search_space_id: int, rule_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> None: + user = auth.user _flag_guard() await _ensure_search_space_membership_admin(session, user, search_space_id) diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index ce21de69d..a00c292d0 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -5,7 +5,7 @@ here" affordance. To prevent accidental usage during the gap we return ``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE`` flag flips. Once enabled, the route runs: -1. Authentication via :func:`current_active_user`. +1. Authentication via an interactive session context. 2. Action lookup; 404 if the action does not belong to the thread. 3. Authorization via :func:`app.services.revert_service.can_revert`. 4. Revert dispatch via :func:`app.services.revert_service.revert_action`. @@ -33,9 +33,9 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags +from app.auth.context import AuthContext from app.db import ( AgentActionLog, - User, get_async_session, ) from app.services.revert_service import ( @@ -45,7 +45,7 @@ from app.services.revert_service import ( load_thread, revert_action, ) -from app.users import current_active_user +from app.users import require_session_context logger = logging.getLogger(__name__) @@ -57,8 +57,9 @@ async def revert_agent_action( thread_id: int, action_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ) -> dict: + user = auth.user flags = get_flags() if flags.disable_new_agent_stack or not flags.enable_revert_route: raise HTTPException( @@ -269,7 +270,7 @@ async def revert_agent_turn( thread_id: int, chat_turn_id: str, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ) -> RevertTurnResponse: """Revert every reversible action emitted during ``chat_turn_id``. @@ -281,6 +282,7 @@ async def revert_agent_turn( Partial success is intentional and returned with HTTP 200. Callers must inspect ``results[*].status`` to find rows that need attention. """ + user = auth.user flags = get_flags() if flags.disable_new_agent_stack or not flags.enable_revert_route: diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index f70b9166b..d5cbc2498 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -10,16 +10,16 @@ from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.connectors.airtable_connector import fetch_airtable_user_email from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -78,7 +78,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str: @router.get("/auth/airtable/connector/add") -async def connect_airtable(space_id: int, user: User = Depends(current_active_user)): +async def connect_airtable( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Airtable OAuth flow. @@ -89,6 +92,7 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") diff --git a/surfsense_backend/app/routes/auth_routes.py b/surfsense_backend/app/routes/auth_routes.py index b1cbaf2a5..be1506a9f 100644 --- a/surfsense_backend/app/routes/auth_routes.py +++ b/surfsense_backend/app/routes/auth_routes.py @@ -5,6 +5,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select +from app.auth.context import AuthContext from app.db import User, async_session_maker from app.schemas.auth import ( LogoutAllResponse, @@ -13,7 +14,7 @@ from app.schemas.auth import ( RefreshTokenRequest, RefreshTokenResponse, ) -from app.users import current_active_user, get_jwt_strategy +from app.users import get_jwt_strategy, require_session_context from app.utils.refresh_tokens import ( revoke_all_user_tokens, revoke_refresh_token, @@ -83,11 +84,14 @@ async def revoke_token(request: LogoutRequest): @router.post("/logout-all", response_model=LogoutAllResponse) -async def logout_all_devices(user: User = Depends(current_active_user)): +async def logout_all_devices( + auth: AuthContext = Depends(require_session_context), +): """ Logout from all devices by revoking all refresh tokens for the user. Requires valid access token. """ + user = auth.user await revoke_all_user_tokens(user.id) logger.info(f"User {user.id} logged out from all devices") return LogoutAllResponse() diff --git a/surfsense_backend/app/routes/chat_comments_routes.py b/surfsense_backend/app/routes/chat_comments_routes.py index f5a8fd0af..2e1eb1d27 100644 --- a/surfsense_backend/app/routes/chat_comments_routes.py +++ b/surfsense_backend/app/routes/chat_comments_routes.py @@ -5,7 +5,8 @@ Routes for chat comments and mentions. from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.schemas.chat_comments import ( CommentBatchRequest, CommentBatchResponse, @@ -25,7 +26,7 @@ from app.services.chat_comments_service import ( get_user_mentions, update_comment, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() @@ -34,20 +35,20 @@ router = APIRouter() async def batch_list_comments( request: CommentBatchRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """Batch-fetch comments for multiple messages in one request.""" - return await get_comments_for_messages_batch(session, request.message_ids, user) + return await get_comments_for_messages_batch(session, request.message_ids, auth) @router.get("/messages/{message_id}/comments", response_model=CommentListResponse) async def list_comments( message_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """List all comments for a message with their replies.""" - return await get_comments_for_message(session, message_id, user) + return await get_comments_for_message(session, message_id, auth) @router.post("/messages/{message_id}/comments", response_model=CommentResponse) @@ -55,10 +56,10 @@ async def add_comment( message_id: int, request: CommentCreateRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """Create a top-level comment on an AI response.""" - return await create_comment(session, message_id, request.content, user) + return await create_comment(session, message_id, request.content, auth) @router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse) @@ -66,10 +67,10 @@ async def add_reply( comment_id: int, request: CommentCreateRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """Reply to an existing comment.""" - return await create_reply(session, comment_id, request.content, user) + return await create_reply(session, comment_id, request.content, auth) @router.put("/comments/{comment_id}", response_model=CommentReplyResponse) @@ -77,20 +78,20 @@ async def edit_comment( comment_id: int, request: CommentUpdateRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """Update a comment's content (author only).""" - return await update_comment(session, comment_id, request.content, user) + return await update_comment(session, comment_id, request.content, auth) @router.delete("/comments/{comment_id}") async def remove_comment( comment_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """Delete a comment (author or user with COMMENTS_DELETE permission).""" - return await delete_comment(session, comment_id, user) + return await delete_comment(session, comment_id, auth) # ============================================================================= @@ -102,7 +103,7 @@ async def remove_comment( async def list_mentions( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """List mentions for the current user.""" - return await get_user_mentions(session, user, search_space_id) + return await get_user_mentions(session, auth, search_space_id) diff --git a/surfsense_backend/app/routes/clickup_add_connector_route.py b/surfsense_backend/app/routes/clickup_add_connector_route.py index f7b0876e5..3be32b217 100644 --- a/surfsense_backend/app/routes/clickup_add_connector_route.py +++ b/surfsense_backend/app/routes/clickup_add_connector_route.py @@ -16,15 +16,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.clickup_auth_credentials import ClickUpAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) @@ -61,7 +61,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/clickup/connector/add") -async def connect_clickup(space_id: int, user: User = Depends(current_active_user)): +async def connect_clickup( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate ClickUp OAuth flow. @@ -72,6 +75,7 @@ async def connect_clickup(space_id: int, user: User = Depends(current_active_use Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") diff --git a/surfsense_backend/app/routes/composio_routes.py b/surfsense_backend/app/routes/composio_routes.py index 7bc2addf8..410f90256 100644 --- a/surfsense_backend/app/routes/composio_routes.py +++ b/surfsense_backend/app/routes/composio_routes.py @@ -22,11 +22,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.services.composio_service import ( @@ -35,12 +35,13 @@ from app.services.composio_service import ( TOOLKIT_TO_CONNECTOR_TYPE, ComposioService, ) -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( count_connectors_of_type, get_base_name_for_type, ) from app.utils.oauth_security import OAuthStateManager +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -68,13 +69,16 @@ def get_state_manager() -> OAuthStateManager: @router.get("/composio/toolkits") -async def list_composio_toolkits(user: User = Depends(current_active_user)): +async def list_composio_toolkits( + auth: AuthContext = Depends(require_session_context), +): """ List available Composio toolkits. Returns: JSON with list of available toolkits and their metadata. """ + del auth if not ComposioService.is_enabled(): raise HTTPException( status_code=503, @@ -98,7 +102,7 @@ async def initiate_composio_auth( toolkit_id: str = Query( ..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')" ), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): """ Initiate Composio OAuth flow for a specific toolkit. @@ -110,6 +114,7 @@ async def initiate_composio_auth( Returns: JSON with auth_url to redirect user to Composio authorization """ + user = auth.user if not ComposioService.is_enabled(): raise HTTPException( status_code=503, @@ -446,7 +451,7 @@ async def reauth_composio_connector( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """ @@ -460,6 +465,7 @@ async def reauth_composio_connector( connector_id: ID of the existing Composio connector to re-authenticate return_url: Optional frontend path to redirect to after completion """ + user = auth.user if not ComposioService.is_enabled(): raise HTTPException( status_code=503, detail="Composio integration is not enabled." @@ -644,7 +650,7 @@ async def list_composio_drive_folders( connector_id: int, parent_id: str | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ List folders AND files in user's Google Drive via Composio. @@ -659,6 +665,7 @@ async def list_composio_drive_folders( ) connector = None + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -676,6 +683,8 @@ async def list_composio_drive_folders( detail="Composio Google Drive connector not found or access denied", ) + await check_search_space_access(session, auth, connector.search_space_id) + composio_connected_account_id = connector.config.get( "composio_connected_account_id" ) diff --git a/surfsense_backend/app/routes/confluence_add_connector_route.py b/surfsense_backend/app/routes/confluence_add_connector_route.py index 42235e240..cc9e681bf 100644 --- a/surfsense_backend/app/routes/confluence_add_connector_route.py +++ b/surfsense_backend/app/routes/confluence_add_connector_route.py @@ -15,15 +15,15 @@ from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, @@ -77,7 +77,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/confluence/connector/add") -async def connect_confluence(space_id: int, user: User = Depends(current_active_user)): +async def connect_confluence( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Confluence OAuth flow. @@ -88,6 +91,7 @@ async def connect_confluence(space_id: int, user: User = Depends(current_active_ Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -421,10 +425,11 @@ async def reauth_confluence( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Confluence re-authentication to upgrade OAuth scopes.""" + user = auth.user try: from sqlalchemy.future import select diff --git a/surfsense_backend/app/routes/discord_add_connector_route.py b/surfsense_backend/app/routes/discord_add_connector_route.py index 4ab48f544..1da0987b0 100644 --- a/surfsense_backend/app/routes/discord_add_connector_route.py +++ b/surfsense_backend/app/routes/discord_add_connector_route.py @@ -15,21 +15,22 @@ from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, generate_unique_connector_name, ) from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -77,7 +78,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/discord/connector/add") -async def connect_discord(space_id: int, user: User = Depends(current_active_user)): +async def connect_discord( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Discord OAuth flow. @@ -88,6 +92,7 @@ async def connect_discord(space_id: int, user: User = Depends(current_active_use Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -610,7 +615,7 @@ def _compute_channel_permissions( async def get_discord_channels( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Get list of Discord text channels for a connector with permission info. @@ -628,6 +633,7 @@ async def get_discord_channels( """ from sqlalchemy import select + user = auth.user try: # Get connector and verify ownership result = await session.execute( @@ -646,6 +652,8 @@ async def get_discord_channels( detail="Discord connector not found or access denied", ) + await check_search_space_access(session, auth, connector.search_space_id) + # Get credentials and decrypt bot token credentials = DiscordAuthCredentialsBase.from_dict(connector.config) token_encryption = get_token_encryption() diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 53f03a0ca..9d908f4a1 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.agents.chat.runtime.path_resolver import virtual_path_to_doc from app.db import ( Chunk, @@ -35,7 +36,7 @@ from app.schemas import ( PaginatedResponse, ) from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission try: @@ -60,8 +61,9 @@ MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per file async def create_documents( request: DocumentsCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create new documents. Requires DOCUMENTS_CREATE permission. @@ -70,7 +72,7 @@ async def create_documents( # Check permission await check_permission( session, - user, + auth, request.search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create documents in this search space", @@ -128,9 +130,10 @@ async def create_documents_file_upload( use_vision_llm: bool = Form(False), processing_mode: str = Form("basic"), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), dispatcher: TaskDispatcher = Depends(get_task_dispatcher), ): + user = auth.user """ Upload files as documents with real-time status tracking. @@ -159,7 +162,7 @@ async def create_documents_file_upload( try: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create documents in this search space", @@ -340,8 +343,9 @@ async def read_documents( sort_by: str = "created_at", sort_order: str = "desc", session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List documents the user has access to, with optional filtering and pagination. Requires DOCUMENTS_READ permission for the search space(s). @@ -369,7 +373,7 @@ async def read_documents( if search_space_id is not None: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -519,8 +523,9 @@ async def search_documents( search_space_id: int | None = None, document_types: str | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Search documents by title substring, optionally filtered by search_space_id and document_types. Requires DOCUMENTS_READ permission for the search space(s). @@ -549,7 +554,7 @@ async def search_documents( if search_space_id is not None: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -677,8 +682,9 @@ async def search_document_titles( page: int = 0, page_size: int = 20, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Lightweight document title search optimized for mention picker (@mentions). @@ -703,7 +709,7 @@ async def search_document_titles( # Check permission for the search space await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -781,8 +787,9 @@ async def get_document_by_virtual_path( search_space_id: int, virtual_path: str, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Resolve a knowledge-base document by its agent-facing virtual path. The agent renders every document under ``/documents/...`` with a @@ -804,7 +811,7 @@ async def get_document_by_virtual_path( try: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -838,8 +845,9 @@ async def get_documents_status( search_space_id: int, document_ids: str, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Batch status endpoint for documents in a search space. @@ -849,7 +857,7 @@ async def get_documents_status( try: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -905,8 +913,9 @@ async def get_documents_status( async def get_document_type_counts( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get counts of documents by type for search spaces the user has access to. Requires DOCUMENTS_READ permission for the search space(s). @@ -926,7 +935,7 @@ async def get_document_type_counts( # Check permission for specific search space await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -965,7 +974,7 @@ async def get_document_by_chunk_id( 5, ge=0, description="Number of chunks before/after the cited chunk to include" ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Retrieves a document based on a chunk ID, including a window of chunks around the cited one. @@ -995,7 +1004,7 @@ async def get_document_by_chunk_id( await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -1060,12 +1069,13 @@ async def get_document_by_chunk_id( async def get_watched_folders( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Return root folders that are marked as watched (metadata->>'watched' = 'true').""" await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -1101,8 +1111,9 @@ async def get_document_chunks_paginated( None, ge=0, description="Direct offset; overrides page * page_size" ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Paginated chunk loading for a document. Supports both page-based and offset-based access. @@ -1120,7 +1131,7 @@ async def get_document_chunks_paginated( await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -1162,8 +1173,9 @@ async def get_document_chunks_paginated( async def read_document( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific document by ID. Requires DOCUMENTS_READ permission for the search space. @@ -1182,7 +1194,7 @@ async def read_document( # Check permission for the search space await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -1216,8 +1228,9 @@ async def update_document( document_id: int, document_update: DocumentUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a document. Requires DOCUMENTS_UPDATE permission for the search space. @@ -1236,7 +1249,7 @@ async def update_document( # Check permission for the search space await check_permission( session, - user, + auth, db_document.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to update documents in this search space", @@ -1275,8 +1288,9 @@ async def update_document( async def delete_document( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a document. Requires DOCUMENTS_DELETE permission for the search space. @@ -1311,7 +1325,7 @@ async def delete_document( # Check permission for the search space await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_DELETE.value, "You don't have permission to delete documents in this search space", @@ -1355,8 +1369,9 @@ async def delete_document( async def list_document_versions( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """List all versions for a document, ordered by version_number descending.""" document = ( await session.execute(select(Document).where(Document.id == document_id)) @@ -1396,8 +1411,9 @@ async def get_document_version( document_id: int, version_number: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Get full version content including source_markdown.""" document = ( await session.execute(select(Document).where(Document.id == document_id)) @@ -1434,8 +1450,9 @@ async def restore_document_version( document_id: int, version_number: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Restore a previous version: snapshot current state, then overwrite document content.""" document = ( await session.execute(select(Document).where(Document.id == document_id)) @@ -1517,8 +1534,9 @@ class FolderSyncFinalizeRequest(PydanticBaseModel): async def folder_mtime_check( request: FolderMtimeCheckRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Pre-upload optimization: check which files need uploading based on mtime. Returns the subset of relative paths where the file is new or has a @@ -1528,7 +1546,7 @@ async def folder_mtime_check( await check_permission( session, - user, + auth, request.search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create documents in this search space", @@ -1587,8 +1605,9 @@ async def folder_upload( use_vision_llm: bool = Form(False), processing_mode: str = Form("basic"), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Upload files from the desktop app for folder indexing. Files are written to temp storage and dispatched to a Celery task. @@ -1603,7 +1622,7 @@ async def folder_upload( await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create documents in this search space", @@ -1733,8 +1752,9 @@ async def folder_upload( async def folder_unlink( request: FolderUnlinkRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Handle file deletion events from the desktop watcher. For each relative path, find the matching document and delete it. @@ -1746,7 +1766,7 @@ async def folder_unlink( await check_permission( session, - user, + auth, request.search_space_id, Permission.DOCUMENTS_DELETE.value, "You don't have permission to delete documents in this search space", @@ -1787,8 +1807,9 @@ async def folder_unlink( async def folder_sync_finalize( request: FolderSyncFinalizeRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Finalize a full folder scan by deleting orphaned documents. The client sends the complete list of relative paths currently in the @@ -1803,7 +1824,7 @@ async def folder_sync_finalize( await check_permission( session, - user, + auth, request.search_space_id, Permission.DOCUMENTS_DELETE.value, "You don't have permission to delete documents in this search space", diff --git a/surfsense_backend/app/routes/dropbox_add_connector_route.py b/surfsense_backend/app/routes/dropbox_add_connector_route.py index 1dba64467..6a9284371 100644 --- a/surfsense_backend/app/routes/dropbox_add_connector_route.py +++ b/surfsense_backend/app/routes/dropbox_add_connector_route.py @@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.connectors.dropbox import DropboxClient, list_folder_contents from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, generate_unique_connector_name, ) from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) router = APIRouter() @@ -66,8 +67,12 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/dropbox/connector/add") -async def connect_dropbox(space_id: int, user: User = Depends(current_active_user)): +async def connect_dropbox( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """Initiate Dropbox OAuth flow.""" + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -109,10 +114,11 @@ async def reauth_dropbox( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Re-authenticate an existing Dropbox connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -405,10 +411,11 @@ async def list_dropbox_folders( connector_id: int, parent_path: str = "", session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """List folders and files in user's Dropbox.""" connector = None + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -424,6 +431,8 @@ async def list_dropbox_folders( status_code=404, detail="Dropbox connector not found or access denied" ) + await check_search_space_access(session, auth, connector.search_space_id) + dropbox_client = DropboxClient(session, connector_id) items, error = await list_folder_contents(dropbox_client, path=parent_path) diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 8250fff98..fe00995ea 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -18,6 +18,7 @@ from fastapi.responses import StreamingResponse from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session from app.routes.reports_routes import ( _FILE_EXTENSIONS, @@ -31,7 +32,7 @@ from app.templates.export_helpers import ( get_reference_docx_path, get_typst_template_path, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission logger = logging.getLogger(__name__) @@ -47,8 +48,9 @@ async def get_editor_content( search_space_id: int, document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get document content for editing. @@ -60,7 +62,7 @@ async def get_editor_content( # Check RBAC permission await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -178,15 +180,16 @@ async def download_document_markdown( search_space_id: int, document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Download the full document content as a .md file. Reconstructs markdown from source_markdown or chunks. """ await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", @@ -244,8 +247,9 @@ async def save_document( document_id: int, data: dict[str, Any], session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Save document markdown and trigger reindexing. Called when user clicks 'Save & Exit'. @@ -259,7 +263,7 @@ async def save_document( # Check RBAC permission await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to update documents in this search space", @@ -331,12 +335,13 @@ async def export_document( description="Export format: pdf, docx, html, latex, epub, odt, or plain", ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Export a document in the requested format (reuses the report export pipeline).""" await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read documents in this search space", diff --git a/surfsense_backend/app/routes/export_routes.py b/surfsense_backend/app/routes/export_routes.py index 4f2b545a3..8e419157f 100644 --- a/surfsense_backend/app/routes/export_routes.py +++ b/surfsense_backend/app/routes/export_routes.py @@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import Permission, User, get_async_session from app.services.export_service import build_export_zip -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission logger = logging.getLogger(__name__) @@ -24,12 +25,13 @@ async def export_knowledge_base( None, description="Export only this folder's subtree" ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Export documents as a ZIP of markdown files preserving folder structure.""" await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to export documents in this search space", diff --git a/surfsense_backend/app/routes/folders_routes.py b/surfsense_backend/app/routes/folders_routes.py index dca55f31e..8a5dfcb73 100644 --- a/surfsense_backend/app/routes/folders_routes.py +++ b/surfsense_backend/app/routes/folders_routes.py @@ -5,6 +5,7 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import Document, Folder, Permission, User, get_async_session from app.schemas import ( BulkDocumentMove, @@ -23,7 +24,7 @@ from app.services.folder_service import ( get_subtree_max_depth, validate_folder_depth, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission router = APIRouter() @@ -33,13 +34,14 @@ router = APIRouter() async def create_folder( request: FolderCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Create a new folder. Requires DOCUMENTS_CREATE permission.""" try: await check_permission( session, - user, + auth, request.search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create folders in this search space", @@ -91,13 +93,14 @@ async def create_folder( async def list_folders( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """List all folders in a search space (flat). Requires DOCUMENTS_READ permission.""" try: await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read folders in this search space", @@ -122,8 +125,9 @@ async def list_folders( async def get_folder( folder_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Get a single folder. Requires DOCUMENTS_READ permission.""" try: folder = await session.get(Folder, folder_id) @@ -132,7 +136,7 @@ async def get_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read folders in this search space", @@ -152,8 +156,9 @@ async def get_folder( async def get_folder_breadcrumb( folder_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission.""" try: folder = await session.get(Folder, folder_id) @@ -162,7 +167,7 @@ async def get_folder_breadcrumb( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read folders in this search space", @@ -196,8 +201,9 @@ async def get_folder_breadcrumb( async def stop_watching_folder( folder_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Clear the watched flag from a folder's metadata.""" folder = await session.get(Folder, folder_id) if not folder: @@ -205,7 +211,7 @@ async def stop_watching_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to update folders in this search space", @@ -224,8 +230,9 @@ async def update_folder( folder_id: int, request: FolderUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Rename a folder. Requires DOCUMENTS_UPDATE permission.""" try: folder = await session.get(Folder, folder_id) @@ -234,7 +241,7 @@ async def update_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to update folders in this search space", @@ -264,8 +271,9 @@ async def move_folder( folder_id: int, request: FolderMove, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission.""" try: folder = await session.get(Folder, folder_id) @@ -274,7 +282,7 @@ async def move_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to move folders in this search space", @@ -324,8 +332,9 @@ async def reorder_folder( folder_id: int, request: FolderReorder, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE.""" try: folder = await session.get(Folder, folder_id) @@ -334,7 +343,7 @@ async def reorder_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to reorder folders in this search space", @@ -365,8 +374,9 @@ async def reorder_folder( async def delete_folder( folder_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Mark documents for deletion and dispatch Celery to delete docs first, then folders.""" try: folder = await session.get(Folder, folder_id) @@ -375,7 +385,7 @@ async def delete_folder( await check_permission( session, - user, + auth, folder.search_space_id, Permission.DOCUMENTS_DELETE.value, "You don't have permission to delete folders in this search space", @@ -439,8 +449,9 @@ async def move_document( document_id: int, request: DocumentMove, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission.""" try: result = await session.execute( @@ -452,7 +463,7 @@ async def move_document( await check_permission( session, - user, + auth, document.search_space_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to move documents in this search space", @@ -485,8 +496,9 @@ async def move_document( async def bulk_move_documents( request: BulkDocumentMove, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission.""" try: if not request.document_ids: @@ -504,7 +516,7 @@ async def bulk_move_documents( for ss_id in search_space_ids: await check_permission( session, - user, + auth, ss_id, Permission.DOCUMENTS_UPDATE.value, "You don't have permission to move documents in this search space", diff --git a/surfsense_backend/app/routes/gateway_webhook_routes.py b/surfsense_backend/app/routes/gateway_webhook_routes.py index 9b4af4b83..0d05f4baf 100644 --- a/surfsense_backend/app/routes/gateway_webhook_routes.py +++ b/surfsense_backend/app/routes/gateway_webhook_routes.py @@ -20,6 +20,7 @@ from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import JSONResponse, RedirectResponse, Response +from app.auth.context import AuthContext from app.config import config from app.db import ( ExternalChatAccount, @@ -51,7 +52,7 @@ from app.observability.metrics import ( record_gateway_inbox_write, record_gateway_webhook_parse_error, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.oauth_security import OAuthStateManager, TokenEncryption from app.utils.rbac import check_search_space_access @@ -250,14 +251,15 @@ def _telegram_message(payload: dict[str, Any]) -> dict[str, Any] | None: @router.get("/slack/install") async def install_slack_gateway( search_space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, str]: + user = auth.user if not _slack_gateway_enabled(): raise HTTPException( status_code=500, detail="Slack gateway OAuth is not configured" ) - await check_search_space_access(session, user, search_space_id) + await check_search_space_access(session, auth, search_space_id) state = _get_state_manager().generate_secure_state(search_space_id, user.id) auth_params = { "client_id": config.GATEWAY_SLACK_CLIENT_ID, @@ -409,14 +411,15 @@ async def slack_gateway_callback( @router.get("/discord/install") async def install_discord_gateway( search_space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, str]: + user = auth.user if not _discord_gateway_enabled(): raise HTTPException( status_code=500, detail="Discord gateway OAuth is not configured" ) - await check_search_space_access(session, user, search_space_id) + await check_search_space_access(session, auth, search_space_id) state = _get_state_manager().generate_secure_state(search_space_id, user.id) auth_params = { "client_id": config.DISCORD_CLIENT_ID, @@ -712,10 +715,11 @@ async def telegram_webhook( @router.post("/bindings/start", response_model=StartBindingResponse) async def start_binding( body: StartBindingRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> StartBindingResponse: - await check_search_space_access(session, user, body.search_space_id) + user = auth.user + await check_search_space_access(session, auth, body.search_space_id) code = generate_pairing_code() if body.platform == ExternalChatPlatform.TELEGRAM: if not _telegram_gateway_enabled(): @@ -774,9 +778,10 @@ async def start_binding( @router.get("/bindings") async def list_bindings( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> list[dict[str, Any]]: + user = auth.user result = await session.execute( select(ExternalChatBinding, ExternalChatAccount) .join( @@ -803,9 +808,10 @@ async def list_bindings( @router.get("/connections") async def list_connections( platform: ExternalChatPlatform | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> list[dict[str, Any]]: + user = auth.user active_whatsapp_mode = _active_whatsapp_account_mode() if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is None: return [] @@ -946,9 +952,10 @@ async def list_connections( @router.get("/platforms") async def list_platforms( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> list[dict[str, Any]]: + user = auth.user result = await session.execute( select(ExternalChatAccount).where( (ExternalChatAccount.owner_user_id == user.id) @@ -970,8 +977,9 @@ async def list_platforms( @config_router.get("/config") async def get_gateway_config( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> dict[str, bool | str]: + user = auth.user if not config.GATEWAY_ENABLED: return { "enabled": False, @@ -993,9 +1001,10 @@ async def get_gateway_config( async def update_binding_search_space( binding_id: int, body: UpdateBindingSearchSpaceRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, bool]: + user = auth.user binding = await session.get(ExternalChatBinding, binding_id) if binding is None or binding.user_id != user.id: raise HTTPException(status_code=404, detail="Binding not found") @@ -1010,7 +1019,7 @@ async def update_binding_search_space( if account is None or _is_inactive_whatsapp_account(account): raise HTTPException(status_code=404, detail="Binding not found") - await check_search_space_access(session, user, body.search_space_id) + await check_search_space_access(session, auth, body.search_space_id) if binding.search_space_id != body.search_space_id: binding.search_space_id = body.search_space_id binding.new_chat_thread_id = None @@ -1023,9 +1032,10 @@ async def update_binding_search_space( async def update_gateway_account_search_space( account_id: int, body: UpdateAccountSearchSpaceRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, bool]: + user = auth.user account = await session.get(ExternalChatAccount, account_id) if ( account is None @@ -1036,7 +1046,7 @@ async def update_gateway_account_search_space( ): raise HTTPException(status_code=404, detail="Gateway account not found") - await check_search_space_access(session, user, body.search_space_id) + await check_search_space_access(session, auth, body.search_space_id) account.owner_search_space_id = body.search_space_id account.updated_at = datetime.now(UTC) @@ -1061,9 +1071,10 @@ async def update_gateway_account_search_space( @router.delete("/bindings/{binding_id}") async def delete_binding( binding_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, bool]: + user = auth.user binding = await session.get(ExternalChatBinding, binding_id) if binding is None or binding.user_id != user.id: raise HTTPException(status_code=404, detail="Binding not found") @@ -1078,9 +1089,10 @@ async def delete_binding( @router.delete("/accounts/{account_id}") async def delete_gateway_account( account_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, bool]: + user = auth.user account = await session.get(ExternalChatAccount, account_id) if ( account is None @@ -1114,9 +1126,10 @@ async def delete_gateway_account( @router.post("/bindings/{binding_id}/resume") async def resume_external_chat_binding( binding_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, bool]: + user = auth.user binding = await session.get(ExternalChatBinding, binding_id) if binding is None or binding.user_id != user.id: raise HTTPException(status_code=404, detail="Binding not found") diff --git a/surfsense_backend/app/routes/gateway_whatsapp_baileys_routes.py b/surfsense_backend/app/routes/gateway_whatsapp_baileys_routes.py index 1fcf5c438..370b1cc8d 100644 --- a/surfsense_backend/app/routes/gateway_whatsapp_baileys_routes.py +++ b/surfsense_backend/app/routes/gateway_whatsapp_baileys_routes.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( ExternalChatAccount, @@ -20,7 +21,7 @@ from app.db import ( get_async_session, ) from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_search_space_access router = APIRouter(prefix="/gateway/whatsapp/baileys", tags=["gateway"]) @@ -60,11 +61,12 @@ async def _get_user_whatsapp_account( @router.post("/pair") async def request_pairing_code( body: BaileysPairRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> dict[str, Any]: + user = auth.user _ensure_baileys_enabled() - await check_search_space_access(session, user, body.search_space_id) + await check_search_space_access(session, auth, body.search_space_id) adapter = WhatsAppBaileysAdapter() try: pairing = await adapter.request_pairing_code(phone_number=body.phone_number) @@ -97,8 +99,9 @@ async def request_pairing_code( @router.get("/health") async def bridge_health( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> dict[str, Any]: + user = auth.user _ensure_baileys_enabled() adapter = WhatsAppBaileysAdapter() try: diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index a143fd50d..8789287b8 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.connectors.google_gmail_connector import fetch_google_user_email from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -88,7 +88,11 @@ def get_google_flow(): @router.get("/auth/google/calendar/connector/add") -async def connect_calendar(space_id: int, user: User = Depends(current_active_user)): +async def connect_calendar( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -127,10 +131,11 @@ async def reauth_calendar( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Google Calendar re-authentication for an existing connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py index 8706326b7..c97c82eb0 100644 --- a/surfsense_backend/app/routes/google_drive_add_connector_route.py +++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py @@ -23,6 +23,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.config import config from app.connectors.google_drive import ( GoogleDriveClient, @@ -33,10 +34,9 @@ from app.connectors.google_gmail_connector import fetch_google_user_email from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -46,6 +46,7 @@ from app.utils.oauth_security import ( TokenEncryption, generate_code_verifier, ) +from app.utils.rbac import check_search_space_access # Relax token scope validation for Google OAuth os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" @@ -110,7 +111,10 @@ def get_google_flow(): @router.get("/auth/google/drive/connector/add") -async def connect_drive(space_id: int, user: User = Depends(current_active_user)): +async def connect_drive( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Google Drive OAuth flow. @@ -120,6 +124,7 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) Returns: JSON with auth_url to redirect user to Google authorization """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -165,7 +170,7 @@ async def reauth_drive( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """ @@ -178,6 +183,7 @@ async def reauth_drive( Returns: JSON with auth_url to redirect user to Google authorization """ + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -472,7 +478,7 @@ async def list_google_drive_folders( connector_id: int, parent_id: str | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ List folders AND files in user's Google Drive with hierarchical support. @@ -492,6 +498,7 @@ async def list_google_drive_folders( ] } """ + user = auth.user try: # Get connector and verify ownership result = await session.execute( @@ -510,6 +517,8 @@ async def list_google_drive_folders( detail="Google Drive connector not found or access denied", ) + await check_search_space_access(session, auth, connector.search_space_id) + # Initialize Drive client (credentials will be loaded on first API call) drive_client = GoogleDriveClient(session, connector_id) diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index 9b807a556..82475c792 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.connectors.google_gmail_connector import fetch_google_user_email from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -92,7 +92,10 @@ def get_google_flow(): @router.get("/auth/google/gmail/connector/add") -async def connect_gmail(space_id: int, user: User = Depends(current_active_user)): +async def connect_gmail( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Google Gmail OAuth flow. @@ -102,6 +105,7 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) Returns: JSON with auth_url to redirect user to Google authorization """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -145,10 +149,11 @@ async def reauth_gmail( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Gmail re-authentication for an existing connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 5bd058cb1..9376c8f0f 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -16,6 +16,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.config import config from app.db import ( ImageGeneration, @@ -46,7 +47,7 @@ from app.services.image_gen_router_service import ( ) from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -245,8 +246,9 @@ async def _execute_image_generation( async def create_image_generation( data: ImageGenerationCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Create and execute an image generation request. Premium configs are gated by the user's shared premium credit pool. @@ -270,7 +272,7 @@ async def create_image_generation( try: await check_permission( session, - user, + auth, data.search_space_id, Permission.IMAGE_GENERATIONS_CREATE.value, "You don't have permission to create image generations in this search space", @@ -365,8 +367,9 @@ async def list_image_generations( skip: int = 0, limit: int = 50, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """List image generations.""" if skip < 0 or limit < 1: raise HTTPException(status_code=400, detail="Invalid pagination parameters") @@ -377,7 +380,7 @@ async def list_image_generations( if search_space_id is not None: await check_permission( session, - user, + auth, search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to read image generations in this search space", @@ -417,8 +420,9 @@ async def list_image_generations( async def get_image_generation( image_gen_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Get a specific image generation by ID.""" try: result = await session.execute( @@ -430,7 +434,7 @@ async def get_image_generation( await check_permission( session, - user, + auth, image_gen.search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to read image generations in this search space", @@ -449,8 +453,9 @@ async def get_image_generation( async def delete_image_generation( image_gen_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Delete an image generation record.""" try: result = await session.execute( @@ -462,7 +467,7 @@ async def delete_image_generation( await check_permission( session, - user, + auth, db_image_gen.search_space_id, Permission.IMAGE_GENERATIONS_DELETE.value, "You don't have permission to delete image generations in this search space", diff --git a/surfsense_backend/app/routes/incentive_tasks_routes.py b/surfsense_backend/app/routes/incentive_tasks_routes.py index 1dae09a2d..2635df42f 100644 --- a/surfsense_backend/app/routes/incentive_tasks_routes.py +++ b/surfsense_backend/app/routes/incentive_tasks_routes.py @@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( INCENTIVE_TASKS_CONFIG, IncentiveTaskType, - User, UserIncentiveTask, get_async_session, ) @@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import ( IncentiveTasksResponse, TaskAlreadyCompletedResponse, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"]) @router.get("", response_model=IncentiveTasksResponse) async def get_incentive_tasks( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> IncentiveTasksResponse: """ Get all available incentive tasks with the user's completion status. """ + user = auth.user # Get all completed tasks for this user result = await session.execute( select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id) @@ -75,7 +76,7 @@ async def get_incentive_tasks( ) async def complete_task( task_type: IncentiveTaskType, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> CompleteTaskResponse | TaskAlreadyCompletedResponse: """ @@ -84,6 +85,7 @@ async def complete_task( Each task can only be completed once. If the task was already completed, returns the existing completion information without awarding additional credit. """ + user = auth.user # Validate task type exists in config task_config = INCENTIVE_TASKS_CONFIG.get(task_type) if not task_config: diff --git a/surfsense_backend/app/routes/jira_add_connector_route.py b/surfsense_backend/app/routes/jira_add_connector_route.py index eeb4f91d9..c29d0609b 100644 --- a/surfsense_backend/app/routes/jira_add_connector_route.py +++ b/surfsense_backend/app/routes/jira_add_connector_route.py @@ -16,15 +16,15 @@ from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, @@ -75,7 +75,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/jira/connector/add") -async def connect_jira(space_id: int, user: User = Depends(current_active_user)): +async def connect_jira( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Jira OAuth flow. @@ -86,6 +89,7 @@ async def connect_jira(space_id: int, user: User = Depends(current_active_user)) Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -438,10 +442,11 @@ async def reauth_jira( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Jira re-authentication to upgrade OAuth scopes.""" + user = auth.user try: from sqlalchemy.future import select diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py index f59c17d25..1d7cc172f 100644 --- a/surfsense_backend/app/routes/linear_add_connector_route.py +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -17,16 +17,16 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.connectors.linear_connector import fetch_linear_organization_name from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -79,7 +79,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str: @router.get("/auth/linear/connector/add") -async def connect_linear(space_id: int, user: User = Depends(current_active_user)): +async def connect_linear( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Linear OAuth flow. @@ -90,6 +93,7 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -134,10 +138,11 @@ async def reauth_linear( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Linear re-authentication for an existing connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/logs_routes.py b/surfsense_backend/app/routes/logs_routes.py index b82e02077..16400ef0b 100644 --- a/surfsense_backend/app/routes/logs_routes.py +++ b/surfsense_backend/app/routes/logs_routes.py @@ -5,6 +5,7 @@ from sqlalchemy import and_, desc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import ( Log, LogLevel, @@ -16,7 +17,7 @@ from app.db import ( get_async_session, ) from app.schemas import LogCreate, LogRead, LogUpdate -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission router = APIRouter() @@ -26,8 +27,9 @@ router = APIRouter() async def create_log( log: LogCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new log entry. Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated). @@ -36,7 +38,7 @@ async def create_log( # Check if the user has access to the search space await check_permission( session, - user, + auth, log.search_space_id, Permission.LOGS_READ.value, "You don't have permission to access logs in this search space", @@ -67,8 +69,9 @@ async def read_logs( start_date: datetime | None = None, end_date: datetime | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get logs with optional filtering. Requires LOGS_READ permission for the search space(s). @@ -81,7 +84,7 @@ async def read_logs( # Check permission for specific search space await check_permission( session, - user, + auth, search_space_id, Permission.LOGS_READ.value, "You don't have permission to read logs in this search space", @@ -136,8 +139,9 @@ async def read_logs( async def read_log( log_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific log by ID. Requires LOGS_READ permission for the search space. @@ -152,7 +156,7 @@ async def read_log( # Check permission for the search space await check_permission( session, - user, + auth, log.search_space_id, Permission.LOGS_READ.value, "You don't have permission to read logs in this search space", @@ -172,8 +176,9 @@ async def update_log( log_id: int, log_update: LogUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a log entry. Requires LOGS_READ permission (logs are typically updated by system). @@ -188,7 +193,7 @@ async def update_log( # Check permission for the search space await check_permission( session, - user, + auth, db_log.search_space_id, Permission.LOGS_READ.value, "You don't have permission to access logs in this search space", @@ -215,8 +220,9 @@ async def update_log( async def delete_log( log_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a log entry. Requires LOGS_DELETE permission for the search space. @@ -231,7 +237,7 @@ async def delete_log( # Check permission for the search space await check_permission( session, - user, + auth, db_log.search_space_id, Permission.LOGS_DELETE.value, "You don't have permission to delete logs in this search space", @@ -254,8 +260,9 @@ async def get_logs_summary( search_space_id: int, hours: int = 24, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a summary of logs for a search space in the last X hours. Requires LOGS_READ permission for the search space. @@ -264,7 +271,7 @@ async def get_logs_summary( # Check permission await check_permission( session, - user, + auth, search_space_id, Permission.LOGS_READ.value, "You don't have permission to read logs in this search space", diff --git a/surfsense_backend/app/routes/luma_add_connector_route.py b/surfsense_backend/app/routes/luma_add_connector_route.py index 7040581bc..9a6f18940 100644 --- a/surfsense_backend/app/routes/luma_add_connector_route.py +++ b/surfsense_backend/app/routes/luma_add_connector_route.py @@ -6,13 +6,13 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import require_session_context logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class AddLumaConnectorRequest(BaseModel): @router.post("/connectors/luma/add") async def add_luma_connector( request: AddLumaConnectorRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """ @@ -46,6 +46,7 @@ async def add_luma_connector( Raises: HTTPException: If connector already exists or validation fails """ + user = auth.user try: # Check if a Luma connector already exists for this search space and user result = await session.execute( @@ -118,7 +119,7 @@ async def add_luma_connector( @router.delete("/connectors/luma") async def delete_luma_connector( space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """ @@ -135,6 +136,7 @@ async def delete_luma_connector( Raises: HTTPException: If connector doesn't exist """ + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -173,7 +175,7 @@ async def delete_luma_connector( @router.get("/connectors/luma/test") async def test_luma_connector( space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """ @@ -190,6 +192,7 @@ async def test_luma_connector( Raises: HTTPException: If connector doesn't exist or test fails """ + user = auth.user try: # Get the Luma connector for this search space and user result = await session.execute( diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index fdeb6ecfd..dbeb8738c 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -20,14 +20,14 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import generate_unique_connector_name from app.utils.oauth_security import ( OAuthStateManager, @@ -164,8 +164,9 @@ def _frontend_redirect( async def connect_mcp_service( service: str, space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user from app.services.mcp_oauth.registry import get_service svc = get_service(service) @@ -523,9 +524,10 @@ async def reauth_mcp_service( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user from app.services.mcp_oauth.registry import get_service svc = get_service(service) diff --git a/surfsense_backend/app/routes/memory_routes.py b/surfsense_backend/app/routes/memory_routes.py index 8e73a277c..d2a82a81c 100644 --- a/surfsense_backend/app/routes/memory_routes.py +++ b/surfsense_backend/app/routes/memory_routes.py @@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.services.memory import ( MemoryRead, MemoryScope, @@ -15,7 +16,7 @@ from app.services.memory import ( reset_memory, save_memory, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() @@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel): @router.get("/users/me/memory", response_model=MemoryRead) async def get_user_memory( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user memory_md = await read_memory( scope=MemoryScope.USER, target_id=user.id, @@ -40,9 +42,10 @@ async def get_user_memory( @router.put("/users/me/memory", response_model=MemoryRead) async def update_user_memory( body: MemoryUpdate, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user result = await save_memory( scope=MemoryScope.USER, target_id=user.id, @@ -56,9 +59,10 @@ async def update_user_memory( @router.post("/users/me/memory/reset", response_model=MemoryRead) async def reset_user_memory( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user result = await reset_memory( scope=MemoryScope.USER, target_id=user.id, diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 4d32a32af..d75e1de79 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -5,6 +5,7 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.config import config from app.db import ( Connection, @@ -14,7 +15,6 @@ from app.db import ( NewChatThread, Permission, SearchSpace, - User, get_async_session, ) from app.schemas import ( @@ -42,7 +42,7 @@ from app.services.model_connection_service import ( verify_connection, ) from app.services.provider_registry import REGISTRY -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.rbac import check_permission router = APIRouter() @@ -257,8 +257,8 @@ async def _default_unset_roles( @router.get("/model-providers", response_model=list[ModelProviderRead]) -async def list_model_providers(user: User = Depends(current_active_user)): - del user +async def list_model_providers(auth: AuthContext = Depends(require_session_context)): + del auth local_only = {"ollama_chat", "lm_studio"} return [ ModelProviderRead( @@ -298,14 +298,16 @@ async def _load_connection(session: AsyncSession, connection_id: int) -> Connect async def _assert_connection_access( session: AsyncSession, - user: User, + auth: AuthContext, conn: Connection, permission: str = Permission.LLM_CONFIGS_CREATE.value, + allow_spaceless_pat: bool = False, ) -> None: + user = auth.user if conn.search_space_id: await check_permission( session, - user, + auth, conn.search_space_id, permission, "You don't have permission to manage model connections in this search space", @@ -315,17 +317,22 @@ async def _assert_connection_access( raise HTTPException( status_code=403, detail="Connection does not belong to user" ) + if auth.is_gated and not allow_spaceless_pat: + raise HTTPException( + status_code=403, + detail="Managing personal model connections requires an interactive session", + ) @router.get("/global-llm-config-status") -async def global_llm_config_status(user: User = Depends(current_active_user)): - del user +async def global_llm_config_status(auth: AuthContext = Depends(require_session_context)): + del auth return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS} @router.get("/global-model-connections", response_model=list[ConnectionRead]) -async def list_global_connections(user: User = Depends(current_active_user)): - del user +async def list_global_connections(auth: AuthContext = Depends(require_session_context)): + del auth models_by_connection: dict[int, list[dict]] = {} for model in config.GLOBAL_MODELS: models_by_connection.setdefault(model["connection_id"], []).append(model) @@ -339,13 +346,14 @@ async def list_global_connections(user: User = Depends(current_active_user)): async def list_connections( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user stmt = select(Connection).options(selectinload(Connection.models)) if search_space_id is not None: await check_permission( session, - user, + auth, search_space_id, Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to view model connections in this search space", @@ -363,8 +371,9 @@ async def list_connections( async def create_connection( data: ConnectionCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user if data.scope == ConnectionScope.GLOBAL: raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") if data.scope == ConnectionScope.SEARCH_SPACE: @@ -372,11 +381,16 @@ async def create_connection( raise HTTPException(status_code=400, detail="search_space_id is required") await check_permission( session, - user, + auth, data.search_space_id, Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + elif auth.is_gated: + raise HTTPException( + status_code=403, + detail="Managing personal model connections requires an interactive session", + ) payload = data.model_dump(exclude={"search_space_id", "models"}) conn = Connection( @@ -411,16 +425,22 @@ async def create_connection( async def preview_connection_models( data: ConnectionCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: await check_permission( session, - user, + auth, data.search_space_id, Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + elif auth.is_gated: + raise HTTPException( + status_code=403, + detail="Testing personal model connections requires an interactive session", + ) draft = Connection( provider=data.provider, @@ -445,16 +465,22 @@ async def preview_connection_models( async def test_preview_connection_model( data: ModelTestPreview, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: await check_permission( session, - user, + auth, data.search_space_id, Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + elif auth.is_gated: + raise HTTPException( + status_code=403, + detail="Testing personal model connections requires an interactive session", + ) model_id = data.model_id.strip() if not model_id: @@ -491,11 +517,11 @@ async def update_connection( connection_id: int, data: ConnectionUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value + session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value ) search_space_id = conn.search_space_id for key, value in data.model_dump(exclude_unset=True).items(): @@ -512,11 +538,11 @@ async def update_connection( async def delete_connection( connection_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_DELETE.value + session, auth, conn, Permission.LLM_CONFIGS_DELETE.value ) search_space_id = conn.search_space_id await session.delete(conn) @@ -533,11 +559,11 @@ async def delete_connection( async def verify_model_connection( connection_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value + session, auth, conn, Permission.LLM_CONFIGS_CREATE.value ) result = await verify_connection(conn) return VerifyConnectionResponse( @@ -551,11 +577,11 @@ async def verify_model_connection( async def discover_connection_models( connection_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value + session, auth, conn, Permission.LLM_CONFIGS_CREATE.value ) try: discovered = await discover_models(conn) @@ -595,11 +621,11 @@ async def add_manual_model( connection_id: int, data: ModelCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value + session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value ) model_id = data.model_id.strip() @@ -640,11 +666,11 @@ async def bulk_update_models( connection_id: int, data: ModelsBulkUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): conn = await _load_connection(session, connection_id) await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value + session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value ) search_space_id = conn.search_space_id @@ -674,7 +700,7 @@ async def update_model( model_id: int, data: ModelUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): result = await session.execute( select(Model) @@ -685,7 +711,7 @@ async def update_model( if not model: raise HTTPException(status_code=404, detail="Model not found") await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value + session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value ) search_space_id = model.connection.search_space_id update = data.model_dump(exclude_unset=True) @@ -704,7 +730,7 @@ async def update_model( async def test_connection_model( model_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): result = await session.execute( select(Model) @@ -715,7 +741,7 @@ async def test_connection_model( if not model: raise HTTPException(status_code=404, detail="Model not found") await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value + session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value ) result = await test_model(model.connection, model) await session.commit() @@ -730,11 +756,11 @@ async def test_connection_model( async def get_model_roles( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): await check_permission( session, - user, + auth, search_space_id, Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to view model roles in this search space", @@ -756,11 +782,11 @@ async def update_model_roles( search_space_id: int, data: ModelRolesUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): await check_permission( session, - user, + auth, search_space_id, Permission.LLM_CONFIGS_UPDATE.value, "You don't have permission to update model roles in this search space", diff --git a/surfsense_backend/app/routes/model_list_routes.py b/surfsense_backend/app/routes/model_list_routes.py index 79ae7221f..e2535f684 100644 --- a/surfsense_backend/app/routes/model_list_routes.py +++ b/surfsense_backend/app/routes/model_list_routes.py @@ -10,9 +10,9 @@ import logging from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from app.db import User +from app.auth.context import AuthContext from app.services.model_list_service import get_model_list -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class ModelListItem(BaseModel): @router.get("/models", response_model=list[ModelListItem]) async def list_available_models( - user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ): """ Return all available models grouped by provider. diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5bc2571e..c850c7eed 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -36,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.auth.context import AuthContext from app.config import config from app.db import ( ChatComment, @@ -75,7 +76,7 @@ from app.tasks.chat.streaming.flows import ( stream_new_chat, stream_resume_chat, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.perf import get_perf_logger from app.utils.rbac import check_permission from app.utils.user_message_multimodal import ( @@ -595,8 +596,9 @@ async def list_threads( search_space_id: int, limit: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all accessible threads for the current user in a search space. Returns threads and archived_threads for ThreadListPrimitive. @@ -615,7 +617,7 @@ async def list_threads( try: await check_permission( session, - user, + auth, search_space_id, Permission.CHATS_READ.value, "You don't have permission to read chats in this search space", @@ -702,8 +704,9 @@ async def search_threads( search_space_id: int, title: str, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Search accessible threads by title in a search space. @@ -721,7 +724,7 @@ async def search_threads( try: await check_permission( session, - user, + auth, search_space_id, Permission.CHATS_READ.value, "You don't have permission to read chats in this search space", @@ -794,8 +797,9 @@ async def search_threads( async def create_thread( thread: NewChatThreadCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new chat thread. @@ -807,7 +811,7 @@ async def create_thread( try: await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_CREATE.value, "You don't have permission to create chats in this search space", @@ -852,8 +856,9 @@ async def create_thread( async def get_thread_messages( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a thread with all its messages. This is used by ThreadHistoryAdapter.load() to restore conversation. @@ -877,7 +882,7 @@ async def get_thread_messages( # Check permission to read chats in this search space await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to read chats in this search space", @@ -936,8 +941,9 @@ async def get_thread_messages( async def get_thread_full( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get full thread details with all messages. @@ -964,7 +970,7 @@ async def get_thread_full( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to read chats in this search space", @@ -1005,8 +1011,9 @@ async def update_thread( thread_id: int, thread_update: NewChatThreadUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a thread (title, archived status). Used for renaming and archiving threads. @@ -1027,7 +1034,7 @@ async def update_thread( await check_permission( session, - user, + auth, db_thread.search_space_id, Permission.CHATS_UPDATE.value, "You don't have permission to update chats in this search space", @@ -1074,8 +1081,9 @@ async def update_thread( async def delete_thread( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a thread and all its messages. @@ -1095,7 +1103,7 @@ async def delete_thread( await check_permission( session, - user, + auth, db_thread.search_space_id, Permission.CHATS_DELETE.value, "You don't have permission to delete chats in this search space", @@ -1146,8 +1154,9 @@ async def update_thread_visibility( thread_id: int, visibility_update: NewChatThreadVisibilityUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update the visibility/sharing settings of a thread. @@ -1168,7 +1177,7 @@ async def update_thread_visibility( await check_permission( session, - user, + auth, db_thread.search_space_id, Permission.CHATS_UPDATE.value, "You don't have permission to update chats in this search space", @@ -1217,7 +1226,7 @@ async def update_thread_visibility( async def create_thread_snapshot( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Create a public snapshot of the thread. @@ -1229,7 +1238,7 @@ async def create_thread_snapshot( return await create_snapshot( session=session, thread_id=thread_id, - user=user, + auth=auth, ) @@ -1239,7 +1248,7 @@ async def create_thread_snapshot( async def list_thread_snapshots( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ List all public snapshots for this thread. @@ -1252,7 +1261,7 @@ async def list_thread_snapshots( snapshots=await list_snapshots_for_thread( session=session, thread_id=thread_id, - user=user, + auth=auth, ) ) @@ -1262,7 +1271,7 @@ async def delete_thread_snapshot( thread_id: int, snapshot_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Delete a specific snapshot. @@ -1275,7 +1284,7 @@ async def delete_thread_snapshot( session=session, thread_id=thread_id, snapshot_id=snapshot_id, - user=user, + auth=auth, ) return {"message": "Snapshot deleted successfully"} @@ -1290,8 +1299,9 @@ async def append_message( thread_id: int, request: Request, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ .. deprecated:: 2026-05 Replaced by the **SSE-based message ID handshake**. The streaming @@ -1321,8 +1331,8 @@ async def append_message( Requires CHATS_UPDATE permission. """ try: - # Capture ``user.id`` as a primitive UUID up front. The - # ``current_active_user`` dependency hands us a ``User`` ORM + # Capture ``user.id`` as a primitive UUID up front. The auth + # dependency hands us a ``User`` ORM # row bound to ``session``; if the outer ``except # IntegrityError`` block below ever fires (an unexpected # constraint like a foreign key violation — the common @@ -1370,7 +1380,7 @@ async def append_message( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_UPDATE.value, "You don't have permission to update chats in this search space", @@ -1597,8 +1607,9 @@ async def list_messages( skip: int = 0, limit: int = 100, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List messages in a thread with pagination. @@ -1620,7 +1631,7 @@ async def list_messages( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to read chats in this search space", @@ -1662,7 +1673,7 @@ async def list_messages( @router.get("/agent/tools", response_model=list[AgentToolInfo]) async def list_agent_tools( - _user: User = Depends(current_active_user), + _auth: AuthContext = Depends(get_auth_context), ): """Return the list of built-in agent tools with their metadata. @@ -1691,8 +1702,9 @@ async def handle_new_chat( request: NewChatRequest, http_request: Request, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Stream chat responses from the deep agent. @@ -1717,7 +1729,7 @@ async def handle_new_chat( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_CREATE.value, "You don't have permission to chat in this search space", @@ -1795,6 +1807,7 @@ async def handle_new_chat( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=image_urls, + auth_context=auth, ), media_type="text/event-stream", headers={ @@ -1821,8 +1834,9 @@ async def cancel_active_turn( thread_id: int, response: Response, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Signal cancellation for the currently running turn on ``thread_id``.""" result = await session.execute( select(NewChatThread).filter(NewChatThread.id == thread_id) @@ -1833,7 +1847,7 @@ async def cancel_active_turn( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_UPDATE.value, "You don't have permission to update chats in this search space", @@ -1873,8 +1887,9 @@ async def cancel_active_turn( async def get_turn_status( thread_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user result = await session.execute( select(NewChatThread).filter(NewChatThread.id == thread_id) ) @@ -1884,7 +1899,7 @@ async def get_turn_status( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to view chats in this search space", @@ -1911,8 +1926,9 @@ async def regenerate_response( request: RegenerateRequest, http_request: Request, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Regenerate the AI response for a chat thread. @@ -1947,7 +1963,7 @@ async def regenerate_response( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_UPDATE.value, "You don't have permission to update chats in this search space", @@ -2288,6 +2304,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + auth_context=auth, flow="regenerate", ): yield chunk @@ -2356,8 +2373,9 @@ async def resume_chat( request: ResumeRequest, http_request: Request, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user try: result = await session.execute( select(NewChatThread).filter(NewChatThread.id == thread_id) @@ -2369,7 +2387,7 @@ async def resume_chat( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_CREATE.value, "You don't have permission to chat in this search space", @@ -2413,6 +2431,7 @@ async def resume_chat( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), disabled_tools=request.disabled_tools, + auth_context=auth, ), media_type="text/event-stream", headers={ diff --git a/surfsense_backend/app/routes/notes_routes.py b/surfsense_backend/app/routes/notes_routes.py index 76518de08..e5cca8700 100644 --- a/surfsense_backend/app/routes/notes_routes.py +++ b/surfsense_backend/app/routes/notes_routes.py @@ -9,9 +9,10 @@ from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import Document, DocumentType, Permission, User, get_async_session from app.schemas import DocumentRead, PaginatedResponse -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission router = APIRouter() @@ -27,8 +28,9 @@ async def create_note( search_space_id: int, request: CreateNoteRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new note document. @@ -37,7 +39,7 @@ async def create_note( # Check RBAC permission await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_CREATE.value, "You don't have permission to create notes in this search space", @@ -98,8 +100,9 @@ async def list_notes( page: int | None = None, page_size: int = 50, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all notes in a search space. @@ -108,7 +111,7 @@ async def list_notes( # Check RBAC permission await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_READ.value, "You don't have permission to read notes in this search space", @@ -191,8 +194,9 @@ async def delete_note( search_space_id: int, note_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a note. @@ -201,7 +205,7 @@ async def delete_note( # Check RBAC permission await check_permission( session, - user, + auth, search_space_id, Permission.DOCUMENTS_DELETE.value, "You don't have permission to delete notes in this search space", diff --git a/surfsense_backend/app/routes/notion_add_connector_route.py b/surfsense_backend/app/routes/notion_add_connector_route.py index 16e80ebcb..b0fafb242 100644 --- a/surfsense_backend/app/routes/notion_add_connector_route.py +++ b/surfsense_backend/app/routes/notion_add_connector_route.py @@ -17,15 +17,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, @@ -76,7 +76,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str: @router.get("/auth/notion/connector/add") -async def connect_notion(space_id: int, user: User = Depends(current_active_user)): +async def connect_notion( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Notion OAuth flow. @@ -87,6 +90,7 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -131,10 +135,11 @@ async def reauth_notion( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Initiate Notion re-authentication for an existing connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/oauth_connector_base.py b/surfsense_backend/app/routes/oauth_connector_base.py index 5b75d8519..483caa6c2 100644 --- a/surfsense_backend/app/routes/oauth_connector_base.py +++ b/surfsense_backend/app/routes/oauth_connector_base.py @@ -24,14 +24,14 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, @@ -361,8 +361,9 @@ class OAuthConnectorRoute: @router.get(f"{oauth.auth_prefix}/connector/add") async def connect( space_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -406,9 +407,10 @@ class OAuthConnectorRoute: space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index bd54a4788..56623d61a 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import ( Document, DocumentType, @@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import ( upsert_note, ) from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task -from app.users import current_active_user +from app.users import allow_any_principal, get_auth_context +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification( async def _resolve_vault_connector( session: AsyncSession, *, - user: User, + auth: AuthContext, vault_id: str, ) -> SearchSourceConnector: """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" + user = auth.user # ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the # cross-dialect equivalent of ``.astext`` and compiles to ``->>``. stmt = select(SearchSourceConnector).where( @@ -192,6 +195,7 @@ async def _resolve_vault_connector( connector = (await session.execute(stmt)).scalars().first() if connector is not None: + await check_search_space_access(session, auth, connector.search_space_id) return connector raise HTTPException( @@ -221,10 +225,11 @@ def _queue_obsidian_attachment( async def _ensure_search_space_access( session: AsyncSession, *, - user: User, + auth: AuthContext, search_space_id: int, ) -> SearchSpace: """Owner-only access to the search space (shared spaces are a follow-up).""" + user = auth.user result = await session.execute( select(SearchSpace).where( and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id) @@ -239,6 +244,7 @@ async def _ensure_search_space_access( "message": "You don't own that search space.", }, ) + await check_search_space_access(session, auth, search_space_id) return space @@ -249,7 +255,7 @@ async def _ensure_search_space_access( @router.get("/health", response_model=HealthResponse) async def obsidian_health( - user: User = Depends(current_active_user), + _auth: AuthContext = Depends(allow_any_principal), ) -> HealthResponse: """Return the API contract handshake; plugin caches it per onload.""" return HealthResponse( @@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str: @router.post("/connect", response_model=ConnectResponse) async def obsidian_connect( payload: ConnectRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> ConnectResponse: """Register a vault, refresh an existing one, or adopt another device's row. @@ -321,8 +327,9 @@ async def obsidian_connect( the partial unique index can never produce two live rows for one vault. """ await _ensure_search_space_access( - session, user=user, search_space_id=payload.search_space_id + session, auth=auth, search_space_id=payload.search_space_id ) + user = auth.user now_iso = datetime.now(UTC).isoformat() cfg = _build_config(payload, now_iso=now_iso) @@ -445,13 +452,14 @@ async def obsidian_connect( @router.post("/sync", response_model=SyncAck) async def obsidian_sync( payload: SyncBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> SyncAck: """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) + user = auth.user notification = None try: notification = await _start_obsidian_sync_notification( @@ -551,12 +559,12 @@ async def obsidian_sync( @router.post("/rename", response_model=RenameAck) async def obsidian_rename( payload: RenameBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> RenameAck: """Apply a batch of vault rename events.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) items: list[RenameAckItem] = [] @@ -618,12 +626,12 @@ async def obsidian_rename( @router.delete("/notes", response_model=DeleteAck) async def obsidian_delete_notes( payload: DeleteBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> DeleteAck: """Soft-delete a batch of notes by vault-relative path.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) deleted = 0 @@ -662,18 +670,18 @@ async def obsidian_delete_notes( @router.get("/manifest", response_model=ManifestResponse) async def obsidian_manifest( vault_id: str = Query(..., description="Plugin-side stable vault UUID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> ManifestResponse: """Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff.""" - connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) + connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id) return await get_manifest(session, connector=connector, vault_id=vault_id) @router.get("/stats", response_model=StatsResponse) async def obsidian_stats( vault_id: str = Query(..., description="Plugin-side stable vault UUID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> StatsResponse: """Active-note count + last sync time for the web tile. @@ -681,7 +689,7 @@ async def obsidian_stats( ``files_synced`` excludes tombstones so it matches ``/manifest``; ``last_sync_at`` includes them so deletes advance the freshness signal. """ - connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) + connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id) is_active = Document.document_metadata["deleted_at"].as_string().is_(None) diff --git a/surfsense_backend/app/routes/onedrive_add_connector_route.py b/surfsense_backend/app/routes/onedrive_add_connector_route.py index 2f41efca7..9c55d4fe7 100644 --- a/surfsense_backend/app/routes/onedrive_add_connector_route.py +++ b/surfsense_backend/app/routes/onedrive_add_connector_route.py @@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm.attributes import flag_modified +from app.auth.context import AuthContext from app.config import config from app.connectors.onedrive import OneDriveClient, list_folder_contents from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, generate_unique_connector_name, ) from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) router = APIRouter() @@ -73,8 +74,12 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/onedrive/connector/add") -async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)): +async def connect_onedrive( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """Initiate OneDrive OAuth flow.""" + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -119,10 +124,11 @@ async def reauth_onedrive( space_id: int, connector_id: int, return_url: str | None = None, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): """Re-authenticate an existing OneDrive connector.""" + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -412,10 +418,11 @@ async def list_onedrive_folders( connector_id: int, parent_id: str | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """List folders and files in user's OneDrive.""" connector = None + user = auth.user try: result = await session.execute( select(SearchSourceConnector).filter( @@ -431,6 +438,8 @@ async def list_onedrive_folders( status_code=404, detail="OneDrive connector not found or access denied" ) + await check_search_space_access(session, auth, connector.search_space_id) + onedrive_client = OneDriveClient(session, connector_id) items, error = await list_folder_contents(onedrive_client, parent_id=parent_id) diff --git a/surfsense_backend/app/routes/personal_access_tokens_routes.py b/surfsense_backend/app/routes/personal_access_tokens_routes.py new file mode 100644 index 000000000..a7849a2fc --- /dev/null +++ b/surfsense_backend/app/routes/personal_access_tokens_routes.py @@ -0,0 +1,104 @@ +from datetime import UTC, datetime, timedelta + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.auth.context import AuthContext +from app.config import config +from app.db import PersonalAccessToken, get_async_session +from app.schemas.pat import PATCreate, PATCreated, PATRead +from app.users import require_session_context +from app.utils.pat import generate_pat, hash_pat, token_prefix + +router = APIRouter() + + +def _expires_at(expires_in_days: int | None) -> datetime | None: + max_expiry_days = config.PAT_MAX_EXPIRY_DAYS + + if max_expiry_days is not None: + if expires_in_days is None: + raise HTTPException( + status_code=400, + detail=( + "This deployment requires PATs to have an expiry of " + f"{max_expiry_days} days or less" + ), + ) + if expires_in_days > max_expiry_days: + raise HTTPException( + status_code=400, + detail=f"PAT expiry cannot exceed {max_expiry_days} days", + ) + + if expires_in_days is None: + return None + + return datetime.now(UTC) + timedelta(days=expires_in_days) + + +@router.post("/pats", response_model=PATCreated) +async def create_personal_access_token( + body: PATCreate, + session: AsyncSession = Depends(get_async_session), + auth: AuthContext = Depends(require_session_context), +) -> PATCreated: + token = generate_pat() + pat = PersonalAccessToken( + user_id=auth.user.id, + token_hash=hash_pat(token), + token_prefix=token_prefix(token), + label=body.label.strip(), + expires_at=_expires_at(body.expires_in_days), + ) + session.add(pat) + await session.commit() + await session.refresh(pat) + + return PATCreated( + id=pat.id, + label=pat.label, + token=token, + prefix=pat.token_prefix, + expires_at=pat.expires_at, + ) + + +@router.get("/pats", response_model=list[PATRead]) +async def list_personal_access_tokens( + session: AsyncSession = Depends(get_async_session), + auth: AuthContext = Depends(require_session_context), +) -> list[PATRead]: + result = await session.execute( + select(PersonalAccessToken) + .where(PersonalAccessToken.user_id == auth.user.id) + .order_by(PersonalAccessToken.created_at.desc()) + ) + return [ + PATRead( + id=pat.id, + label=pat.label, + prefix=pat.token_prefix, + expires_at=pat.expires_at, + last_used_at=pat.last_used_at, + created_at=pat.created_at, + ) + for pat in result.scalars().all() + ] + + +@router.delete("/pats/{pat_id}", status_code=204) +async def delete_personal_access_token( + pat_id: int, + session: AsyncSession = Depends(get_async_session), + auth: AuthContext = Depends(require_session_context), +) -> None: + await session.execute( + delete(PersonalAccessToken).where( + PersonalAccessToken.id == pat_id, + PersonalAccessToken.user_id == auth.user.id, + ) + ) + await session.commit() diff --git a/surfsense_backend/app/routes/prompts_routes.py b/surfsense_backend/app/routes/prompts_routes.py index 8dd47537e..b4cb1466c 100644 --- a/surfsense_backend/app/routes/prompts_routes.py +++ b/surfsense_backend/app/routes/prompts_routes.py @@ -3,14 +3,15 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.db import Prompt, SearchSpaceMembership, User, get_async_session +from app.auth.context import AuthContext +from app.db import Prompt, SearchSpaceMembership, get_async_session from app.schemas.prompts import ( PromptCreate, PromptRead, PromptUpdate, PublicPromptRead, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(tags=["Prompts"]) @@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"]) async def list_prompts( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user query = select(Prompt).where(Prompt.user_id == user.id) if search_space_id is not None: query = query.where(Prompt.search_space_id == search_space_id) @@ -33,8 +35,9 @@ async def list_prompts( async def create_prompt( body: PromptCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user if body.search_space_id is not None: membership = await session.execute( select(SearchSpaceMembership).where( @@ -67,8 +70,9 @@ async def update_prompt( prompt_id: int, body: PromptUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, @@ -99,8 +103,9 @@ async def update_prompt( async def delete_prompt( prompt_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, @@ -119,8 +124,9 @@ async def delete_prompt( @router.get("/prompts/public", response_model=list[PublicPromptRead]) async def list_public_prompts( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt) .options(selectinload(Prompt.user)) @@ -141,8 +147,9 @@ async def list_public_prompts( async def copy_public_prompt( prompt_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index 516e976e6..70f012911 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -11,7 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.schemas.new_chat import ( CloneResponse, PublicChatResponse, @@ -23,7 +24,7 @@ from app.services.public_chat_service import ( get_snapshot_report, get_snapshot_video_presentation, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(prefix="/public", tags=["public"]) @@ -46,8 +47,9 @@ async def read_public_chat( async def clone_public_chat( share_token: str, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user """ Clone a public chat snapshot to the user's account. diff --git a/surfsense_backend/app/routes/rbac_routes.py b/surfsense_backend/app/routes/rbac_routes.py index 3b91e456d..3d50d589d 100644 --- a/surfsense_backend/app/routes/rbac_routes.py +++ b/surfsense_backend/app/routes/rbac_routes.py @@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.db import ( Permission, SearchSpace, @@ -43,7 +44,7 @@ from app.schemas import ( RoleUpdate, UserSearchSpaceAccess, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import ( check_permission, check_search_space_access, @@ -107,6 +108,8 @@ PERMISSION_DESCRIPTIONS = { "settings:view": "View search space settings", "settings:update": "Modify search space settings", "settings:delete": "Delete the entire search space", + # API access + "api_access:manage": "Enable or disable programmatic API access for a search space", # Automations "automations:create": "Create automations from chat or JSON", "automations:read": "View automations, their triggers, and run history", @@ -120,8 +123,9 @@ PERMISSION_DESCRIPTIONS = { @router.get("/permissions", response_model=PermissionsListResponse) async def list_all_permissions( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all available permissions that can be assigned to roles. """ @@ -156,8 +160,9 @@ async def create_role( search_space_id: int, role_data: RoleCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new custom role in a search space. Requires ROLES_CREATE permission. @@ -165,7 +170,7 @@ async def create_role( try: await check_permission( session, - user, + auth, search_space_id, Permission.ROLES_CREATE.value, "You don't have permission to create roles", @@ -237,8 +242,9 @@ async def create_role( async def list_roles( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all roles in a search space. Requires ROLES_READ permission. @@ -246,7 +252,7 @@ async def list_roles( try: await check_permission( session, - user, + auth, search_space_id, Permission.ROLES_READ.value, "You don't have permission to view roles", @@ -275,8 +281,9 @@ async def get_role( search_space_id: int, role_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific role by ID. Requires ROLES_READ permission. @@ -284,7 +291,7 @@ async def get_role( try: await check_permission( session, - user, + auth, search_space_id, Permission.ROLES_READ.value, "You don't have permission to view roles", @@ -320,8 +327,9 @@ async def update_role( role_id: int, role_update: RoleUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a role. Requires ROLES_UPDATE permission. @@ -330,7 +338,7 @@ async def update_role( try: await check_permission( session, - user, + auth, search_space_id, Permission.ROLES_UPDATE.value, "You don't have permission to update roles", @@ -417,8 +425,9 @@ async def delete_role( search_space_id: int, role_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a custom role. Requires ROLES_DELETE permission. @@ -427,7 +436,7 @@ async def delete_role( try: await check_permission( session, - user, + auth, search_space_id, Permission.ROLES_DELETE.value, "You don't have permission to delete roles", @@ -474,8 +483,9 @@ async def delete_role( async def list_members( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all members of a search space. Requires MEMBERS_VIEW permission. @@ -483,7 +493,7 @@ async def list_members( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_VIEW.value, "You don't have permission to view members", @@ -539,8 +549,9 @@ async def update_member_role( membership_id: int, membership_update: MembershipUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a member's role. Requires MEMBERS_MANAGE_ROLES permission. @@ -549,7 +560,7 @@ async def update_member_role( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_MANAGE_ROLES.value, "You don't have permission to manage member roles", @@ -629,8 +640,9 @@ async def update_member_role( async def leave_search_space( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Leave a search space (remove own membership). Owners cannot leave their search space. @@ -675,8 +687,9 @@ async def remove_member( search_space_id: int, membership_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Remove a member from a search space. Requires MEMBERS_REMOVE permission. @@ -685,7 +698,7 @@ async def remove_member( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_REMOVE.value, "You don't have permission to remove members", @@ -733,8 +746,9 @@ async def create_invite( search_space_id: int, invite_data: InviteCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new invite link for a search space. Requires MEMBERS_INVITE permission. @@ -742,7 +756,7 @@ async def create_invite( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_INVITE.value, "You don't have permission to create invites", @@ -798,8 +812,9 @@ async def create_invite( async def list_invites( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all invites for a search space. Requires MEMBERS_INVITE permission. @@ -807,7 +822,7 @@ async def list_invites( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_INVITE.value, "You don't have permission to view invites", @@ -837,8 +852,9 @@ async def update_invite( invite_id: int, invite_update: InviteUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update an invite. Requires MEMBERS_INVITE permission. @@ -846,7 +862,7 @@ async def update_invite( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_INVITE.value, "You don't have permission to update invites", @@ -903,8 +919,9 @@ async def revoke_invite( search_space_id: int, invite_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Revoke (delete) an invite. Requires MEMBERS_INVITE permission. @@ -912,7 +929,7 @@ async def revoke_invite( try: await check_permission( session, - user, + auth, search_space_id, Permission.MEMBERS_INVITE.value, "You don't have permission to revoke invites", @@ -1022,8 +1039,9 @@ async def get_invite_info( async def accept_invite( request: InviteAcceptRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Accept an invite and join a search space. """ @@ -1120,13 +1138,14 @@ async def accept_invite( async def get_my_access( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get the current user's access info for a search space. """ try: - membership = await check_search_space_access(session, user, search_space_id) + membership = await check_search_space_access(session, auth, search_space_id) # Get search space name result = await session.execute( diff --git a/surfsense_backend/app/routes/reports_routes.py b/surfsense_backend/app/routes/reports_routes.py index 19961e1a9..d5996485e 100644 --- a/surfsense_backend/app/routes/reports_routes.py +++ b/surfsense_backend/app/routes/reports_routes.py @@ -28,6 +28,7 @@ from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( Report, SearchSpace, @@ -42,7 +43,7 @@ from app.templates.export_helpers import ( get_reference_docx_path, get_typst_template_path, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -158,8 +159,9 @@ def _normalize_latex_delimiters(text: str) -> str: async def _get_report_with_access( report_id: int, session: AsyncSession, - user: User, + auth: AuthContext, ) -> Report: + user = auth.user """Fetch a report and verify the user belongs to its search space. Raises HTTPException(404) if not found, HTTPException(403) if no access. @@ -172,7 +174,7 @@ async def _get_report_with_access( # Lightweight membership check - no granular RBAC, just "is the user a # member of the search space this report belongs to?" - await check_search_space_access(session, user, report.search_space_id) + await check_search_space_access(session, auth, report.search_space_id) return report @@ -206,8 +208,9 @@ async def read_reports( limit: int = Query(default=100, ge=1, le=MAX_REPORT_LIST_LIMIT), search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List reports the user has access to. Filters by search space membership. @@ -215,7 +218,7 @@ async def read_reports( try: if search_space_id is not None: # Verify the caller is a member of the requested search space - await check_search_space_access(session, user, search_space_id) + await check_search_space_access(session, auth, search_space_id) result = await session.execute( select(Report) @@ -247,8 +250,9 @@ async def read_reports( async def read_report( report_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific report by ID (metadata only, no content). """ @@ -266,8 +270,9 @@ async def read_report( async def read_report_content( report_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get full Markdown content of a report, including version siblings. """ @@ -298,8 +303,9 @@ async def update_report_content( report_id: int, body: ReportContentUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update the Markdown content of a report. @@ -339,8 +345,9 @@ async def update_report_content( async def preview_report_pdf( report_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Return a compiled PDF preview for Typst-based reports (resumes). @@ -394,8 +401,9 @@ async def export_report( description="Export format: pdf, docx, html, latex, epub, odt, or plain", ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Export a report in the requested format. """ @@ -568,8 +576,9 @@ async def export_report( async def delete_report( report_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a report. """ diff --git a/surfsense_backend/app/routes/sandbox_routes.py b/surfsense_backend/app/routes/sandbox_routes.py index fefe51997..e7974b993 100644 --- a/surfsense_backend/app/routes/sandbox_routes.py +++ b/surfsense_backend/app/routes/sandbox_routes.py @@ -10,8 +10,9 @@ from fastapi.responses import Response from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import NewChatThread, Permission, User, get_async_session -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission logger = logging.getLogger(__name__) @@ -47,8 +48,9 @@ async def download_sandbox_file( thread_id: int, path: str = Query(..., description="Absolute path of the file inside the sandbox"), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Download a file from the Daytona sandbox associated with a chat thread.""" from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( @@ -68,7 +70,7 @@ async def download_sandbox_file( await check_permission( session, - user, + auth, thread.search_space_id, Permission.CHATS_READ.value, "You don't have permission to access files in this thread", diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 512b52ae4..fab79ab49 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -33,6 +33,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.config import config from app.connectors.github_connector import GitHubConnector from app.db import ( @@ -56,7 +57,7 @@ from app.schemas import ( SearchSourceConnectorUpdate, ) from app.services.composio_service import ComposioService, get_composio_service -from app.users import current_active_user +from app.users import get_auth_context # NOTE: connector indexer functions are imported lazily inside each # ``run_*_indexing`` helper to break a circular import cycle: @@ -143,8 +144,9 @@ class GitHubPATRequest(BaseModel): @router.post("/github/repositories", response_model=list[dict[str, Any]]) async def list_github_repositories( pat_request: GitHubPATRequest, - user: User = Depends(current_active_user), # Ensure the user is logged in + auth: AuthContext = Depends(get_auth_context), # Ensure the user is logged in ): + user = auth.user """ Fetches a list of repositories accessible by the provided GitHub PAT. The PAT is used for this request only and is not stored. @@ -173,8 +175,9 @@ async def create_search_source_connector( ..., description="ID of the search space to associate the connector with" ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new search source connector. Requires CONNECTORS_CREATE permission. @@ -186,7 +189,7 @@ async def create_search_source_connector( # Check if user has permission to create connectors await check_permission( session, - user, + auth, search_space_id, Permission.CONNECTORS_CREATE.value, "You don't have permission to create connectors in this search space", @@ -281,8 +284,9 @@ async def read_search_source_connectors( limit: int = 100, search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all search source connectors for a search space. Requires CONNECTORS_READ permission. @@ -297,7 +301,7 @@ async def read_search_source_connectors( # Check if user has permission to read connectors await check_permission( session, - user, + auth, search_space_id, Permission.CONNECTORS_READ.value, "You don't have permission to view connectors in this search space", @@ -324,8 +328,9 @@ async def read_search_source_connectors( async def read_search_source_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific search source connector by ID. Requires CONNECTORS_READ permission. @@ -345,7 +350,7 @@ async def read_search_source_connector( # Check permission await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_READ.value, "You don't have permission to view this connector", @@ -367,8 +372,9 @@ async def update_search_source_connector( connector_id: int, connector_update: SearchSourceConnectorUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update a search source connector. Requires CONNECTORS_UPDATE permission. @@ -386,7 +392,7 @@ async def update_search_source_connector( # Check permission await check_permission( session, - user, + auth, db_connector.search_space_id, Permission.CONNECTORS_UPDATE.value, "You don't have permission to update this connector", @@ -557,8 +563,9 @@ async def update_search_source_connector( async def delete_search_source_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a search source connector and all its associated documents. @@ -588,7 +595,7 @@ async def delete_search_source_connector( # Check permission await check_permission( session, - user, + auth, db_connector.search_space_id, Permission.CONNECTORS_DELETE.value, "You don't have permission to delete this connector", @@ -725,8 +732,9 @@ async def index_connector_content( description="[Google Drive only] Structured request with folders and files to index", ), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Index content from a KB connector to a search space. @@ -760,7 +768,7 @@ async def index_connector_content( # the read/update/delete handlers — not the client-supplied query param. await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_UPDATE.value, "You don't have permission to index content in this search space", @@ -2645,8 +2653,9 @@ async def create_mcp_connector( connector_data: MCPConnectorCreate, search_space_id: int = Query(..., description="Search space ID"), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Create a new MCP (Model Context Protocol) connector. @@ -2669,7 +2678,7 @@ async def create_mcp_connector( # Check user has permission to create connectors await check_permission( session, - user, + auth, search_space_id, Permission.CONNECTORS_CREATE.value, "You don't have permission to create connectors in this search space", @@ -2724,8 +2733,9 @@ async def create_mcp_connector( async def list_mcp_connectors( search_space_id: int = Query(..., description="Search space ID"), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List all MCP connectors for a search space. @@ -2741,7 +2751,7 @@ async def list_mcp_connectors( # Check user has permission to read connectors await check_permission( session, - user, + auth, search_space_id, Permission.CONNECTORS_READ.value, "You don't have permission to view connectors in this search space", @@ -2775,8 +2785,9 @@ async def list_mcp_connectors( async def get_mcp_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific MCP connector by ID. @@ -2805,7 +2816,7 @@ async def get_mcp_connector( # Check user has permission to read connectors await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_READ.value, "You don't have permission to view this connector", @@ -2828,8 +2839,9 @@ async def update_mcp_connector( connector_id: int, connector_update: MCPConnectorUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Update an MCP connector. @@ -2859,7 +2871,7 @@ async def update_mcp_connector( # Check user has permission to update connectors await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_UPDATE.value, "You don't have permission to update this connector", @@ -2904,8 +2916,9 @@ async def update_mcp_connector( async def delete_mcp_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete an MCP connector. @@ -2931,7 +2944,7 @@ async def delete_mcp_connector( # Check user has permission to delete connectors await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_DELETE.value, "You don't have permission to delete this connector", @@ -2962,8 +2975,9 @@ async def delete_mcp_connector( @router.post("/connectors/mcp/test") async def test_mcp_server_connection( server_config: dict = Body(...), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Test connection to an MCP server and fetch available tools. @@ -3042,8 +3056,9 @@ DRIVE_CONNECTOR_TYPES = { async def get_drive_picker_token( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Return an OAuth access token + client ID for the Google Picker API.""" result = await session.execute( select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id) @@ -3054,7 +3069,7 @@ async def get_drive_picker_token( await check_permission( session, - user, + auth, connector.search_space_id, Permission.CONNECTORS_READ.value, "You don't have permission to access this connector", @@ -3164,8 +3179,9 @@ async def trust_mcp_tool( connector_id: int, body: MCPTrustToolRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Add a tool to the MCP connector's trusted (always-allow) list. Once trusted, the tool executes without HITL approval on subsequent @@ -3209,8 +3225,9 @@ async def untrust_mcp_tool( connector_id: int, body: MCPTrustToolRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Remove a tool from the MCP connector's trusted list. The tool will require HITL approval again on subsequent calls. diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 592a9dd0e..e92f7dfc1 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -5,22 +5,23 @@ from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import ( Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, - User, get_async_session, get_default_roles_config, ) from app.schemas import ( + SearchSpaceApiAccessUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, SearchSpaceWithStats, ) -from app.users import current_active_user +from app.users import allow_any_principal, get_auth_context, require_session_context from app.utils.rbac import check_permission, check_search_space_access logger = logging.getLogger(__name__) @@ -74,8 +75,9 @@ async def create_default_roles_and_membership( async def create_search_space( search_space: SearchSpaceCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user try: search_space_data = search_space.model_dump() @@ -108,8 +110,9 @@ async def read_search_spaces( limit: int = 200, owned_only: bool = False, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(allow_any_principal), ): + user = auth.user """ Get all search spaces the user has access to, with member count and ownership info. @@ -123,11 +126,17 @@ async def read_search_spaces( # Exclude spaces that are pending background deletion not_deleting = ~SearchSpace.name.startswith("[DELETING] ") + api_access_filter = ( + SearchSpace.api_access_enabled == True # noqa: E712 + if auth.is_gated + else True + ) + if owned_only: # Return only search spaces where user is the original creator (user_id) result = await session.execute( select(SearchSpace) - .filter(SearchSpace.user_id == user.id, not_deleting) + .filter(SearchSpace.user_id == user.id, not_deleting, api_access_filter) .order_by(SearchSpace.id.asc()) .offset(skip) .limit(limit) @@ -137,7 +146,11 @@ async def read_search_spaces( result = await session.execute( select(SearchSpace) .join(SearchSpaceMembership) - .filter(SearchSpaceMembership.user_id == user.id, not_deleting) + .filter( + SearchSpaceMembership.user_id == user.id, + not_deleting, + api_access_filter, + ) .order_by(SearchSpace.id.asc()) .offset(skip) .limit(limit) @@ -174,6 +187,7 @@ async def read_search_spaces( created_at=space.created_at, user_id=space.user_id, citations_enabled=space.citations_enabled, + api_access_enabled=space.api_access_enabled, qna_custom_instructions=space.qna_custom_instructions, ai_file_sort_enabled=space.ai_file_sort_enabled, member_count=member_count, @@ -192,7 +206,7 @@ async def read_search_spaces( async def read_search_space( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Get a specific search space by ID. @@ -200,7 +214,7 @@ async def read_search_space( """ try: # Check if user has access (is a member) - await check_search_space_access(session, user, search_space_id) + await check_search_space_access(session, auth, search_space_id) result = await session.execute( select(SearchSpace).filter(SearchSpace.id == search_space_id) @@ -225,7 +239,7 @@ async def update_search_space( search_space_id: int, search_space_update: SearchSpaceUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Update a search space. @@ -235,7 +249,7 @@ async def update_search_space( # Check permission await check_permission( session, - user, + auth, search_space_id, Permission.SETTINGS_UPDATE.value, "You don't have permission to update this search space", @@ -265,17 +279,65 @@ async def update_search_space( ) from e +@router.put("/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead) +async def update_search_space_api_access( + search_space_id: int, + body: SearchSpaceApiAccessUpdate, + session: AsyncSession = Depends(get_async_session), + auth: AuthContext = Depends(get_auth_context), +): + """ + Toggle programmatic API/PAT access for a search space. + Requires API_ACCESS_MANAGE permission. + """ + try: + if not auth.is_session: + raise HTTPException( + status_code=403, + detail="This action requires an interactive session", + ) + + await check_permission( + session, + auth, + search_space_id, + Permission.API_ACCESS_MANAGE.value, + "You don't have permission to manage API access for this search space", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + db_search_space = result.scalars().first() + + if not db_search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + db_search_space.api_access_enabled = body.api_access_enabled + await session.commit() + await session.refresh(db_search_space) + return db_search_space + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to update API access: {e!s}" + ) from e + + @router.post("/searchspaces/{search_space_id}/ai-sort") async def trigger_ai_sort( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """Trigger a full AI file sort for all documents in the search space.""" try: await check_permission( session, - user, + auth, search_space_id, Permission.SETTINGS_UPDATE.value, "You don't have permission to trigger AI sort on this search space", @@ -305,7 +367,7 @@ async def trigger_ai_sort( async def delete_search_space( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ Delete a search space. @@ -318,7 +380,7 @@ async def delete_search_space( # Check permission - only those with SETTINGS_DELETE can delete await check_permission( session, - user, + auth, search_space_id, Permission.SETTINGS_DELETE.value, "You don't have permission to delete this search space", @@ -374,7 +436,7 @@ async def delete_search_space( async def list_search_space_snapshots( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): """ List all public chat snapshots for a search space. @@ -387,6 +449,6 @@ async def list_search_space_snapshots( snapshots = await list_snapshots_for_search_space( session=session, search_space_id=search_space_id, - user=user, + auth=auth, ) return PublicChatSnapshotsBySpaceResponse(snapshots=snapshots) diff --git a/surfsense_backend/app/routes/slack_add_connector_route.py b/surfsense_backend/app/routes/slack_add_connector_route.py index f6a1458a0..ee6f75417 100644 --- a/surfsense_backend/app/routes/slack_add_connector_route.py +++ b/surfsense_backend/app/routes/slack_add_connector_route.py @@ -17,21 +17,22 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase -from app.users import current_active_user +from app.users import get_auth_context, require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, generate_unique_connector_name, ) from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -78,7 +79,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/slack/connector/add") -async def connect_slack(space_id: int, user: User = Depends(current_active_user)): +async def connect_slack( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Slack OAuth flow. @@ -89,6 +93,7 @@ async def connect_slack(space_id: int, user: User = Depends(current_active_user) Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") @@ -525,7 +530,7 @@ async def refresh_slack_token( async def get_slack_channels( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ) -> list[dict[str, Any]]: """ Get list of Slack channels with bot membership status. @@ -541,6 +546,7 @@ async def get_slack_channels( Returns: List of channels with id, name, is_private, and is_member fields """ + user = auth.user try: # Get the connector and verify ownership result = await session.execute( @@ -559,6 +565,8 @@ async def get_slack_channels( detail="Slack connector not found or access denied", ) + await check_search_space_access(session, auth, connector.search_space_id) + # Get credentials and decrypt bot token credentials = SlackAuthCredentialsBase.from_dict(connector.config) token_encryption = get_token_encryption() diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index 23dce58cd..288e38cc2 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -18,6 +18,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from stripe import SignatureVerificationError, StripeClient, StripeError +from app.auth.context import AuthContext from app.config import config from app.db import ( CreditPurchase, @@ -39,7 +40,7 @@ from app.schemas.stripe import ( StripeWebhookResponse, UpdateAutoReloadSettingsRequest, ) -from app.users import current_active_user +from app.users import require_session_context logger = logging.getLogger(__name__) @@ -456,7 +457,7 @@ async def _reconcile_auto_reload_payment_intent( ) async def create_credit_checkout_session( body: CreateCreditCheckoutSessionRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), ) -> CreateCreditCheckoutSessionResponse: """Create a Stripe Checkout Session for buying credit packs. @@ -466,6 +467,7 @@ async def create_credit_checkout_session( cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page (ETL), so $1 of credit always buys $1 worth of usage at cost. """ + user = auth.user _ensure_credit_buying_enabled() stripe_client = get_stripe_client() price_id = _get_required_credit_price_id() @@ -644,7 +646,7 @@ async def stripe_webhook( @router.get("/finalize-checkout", response_model=FinalizeCheckoutResponse) async def finalize_checkout( session_id: str, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), ) -> FinalizeCheckoutResponse: """Synchronously fulfil a credit checkout session from the success page. @@ -659,6 +661,7 @@ async def finalize_checkout( Authorization: the session's ``client_reference_id`` must match the authenticated user's id. """ + user = auth.user stripe_client = get_stripe_client() try: @@ -718,13 +721,14 @@ async def finalize_checkout( @router.get("/credit-status", response_model=CreditStripeStatusResponse) async def get_credit_status( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ) -> CreditStripeStatusResponse: """Return credit-buying availability and current balance for the frontend. ``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE divides by 1M when displaying. """ + user = auth.user return CreditStripeStatusResponse( credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED, credit_micros_balance=user.credit_micros_balance, @@ -733,12 +737,13 @@ async def get_credit_status( @router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse) async def get_credit_purchases( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), offset: int = 0, limit: int = 50, ) -> CreditPurchaseHistoryResponse: """Return the authenticated user's credit purchase history.""" + user = auth.user limit = min(limit, 100) purchases = ( ( @@ -759,7 +764,7 @@ async def get_credit_purchases( @router.get("/purchases", response_model=PagePurchaseHistoryResponse) async def get_page_purchases( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), offset: int = 0, limit: int = 50, @@ -768,6 +773,7 @@ async def get_page_purchases( Page buying is removed; this endpoint stays for historical records. """ + user = auth.user limit = min(limit, 100) purchases = ( ( @@ -804,7 +810,7 @@ def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse: ) async def create_auto_reload_setup_session( body: CreateAutoReloadSetupSessionRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), ) -> CreateAutoReloadSetupSessionResponse: """Start a ``mode=setup`` checkout session to save a card for auto-reload. @@ -813,6 +819,7 @@ async def create_auto_reload_setup_session( Customer so the card can later be charged off-session. On completion the webhook stores the resulting payment method on the user. """ + user = auth.user _ensure_auto_reload_enabled() _ensure_credit_buying_enabled() stripe_client = get_stripe_client() @@ -871,16 +878,17 @@ async def create_auto_reload_setup_session( @router.get("/auto-reload", response_model=AutoReloadSettingsResponse) async def get_auto_reload_settings( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ) -> AutoReloadSettingsResponse: """Return the user's auto-reload configuration and saved-card state.""" + user = auth.user return _auto_reload_settings_response(user) @router.put("/auto-reload", response_model=AutoReloadSettingsResponse) async def update_auto_reload_settings( body: UpdateAutoReloadSettingsRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), db_session: AsyncSession = Depends(get_async_session), ) -> AutoReloadSettingsResponse: """Update auto-reload preferences. @@ -889,6 +897,7 @@ async def update_auto_reload_settings( at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and clears any prior failure flag. """ + user = auth.user _ensure_auto_reload_enabled() locked = ( diff --git a/surfsense_backend/app/routes/team_memory_routes.py b/surfsense_backend/app/routes/team_memory_routes.py index b37a99b03..3ded87d36 100644 --- a/surfsense_backend/app/routes/team_memory_routes.py +++ b/surfsense_backend/app/routes/team_memory_routes.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import User, get_async_session from app.services.memory import ( MemoryRead, @@ -15,7 +16,7 @@ from app.services.memory import ( reset_memory, save_memory, ) -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_search_space_access router = APIRouter() @@ -29,9 +30,10 @@ class TeamMemoryUpdate(BaseModel): async def get_team_memory( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - await check_search_space_access(session, user, search_space_id) + user = auth.user + await check_search_space_access(session, auth, search_space_id) memory_md = await read_memory( scope=MemoryScope.TEAM, target_id=search_space_id, @@ -45,9 +47,10 @@ async def update_team_memory( search_space_id: int, body: TeamMemoryUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - await check_search_space_access(session, user, search_space_id) + user = auth.user + await check_search_space_access(session, auth, search_space_id) result = await save_memory( scope=MemoryScope.TEAM, target_id=search_space_id, @@ -63,9 +66,10 @@ async def update_team_memory( async def reset_team_memory( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): - await check_search_space_access(session, user, search_space_id) + user = auth.user + await check_search_space_access(session, auth, search_space_id) result = await reset_memory( scope=MemoryScope.TEAM, target_id=search_space_id, diff --git a/surfsense_backend/app/routes/teams_add_connector_route.py b/surfsense_backend/app/routes/teams_add_connector_route.py index 9d0f5144f..3782b4720 100644 --- a/surfsense_backend/app/routes/teams_add_connector_route.py +++ b/surfsense_backend/app/routes/teams_add_connector_route.py @@ -14,15 +14,15 @@ from fastapi.responses import RedirectResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, SearchSourceConnectorType, - User, get_async_session, ) from app.schemas.teams_auth_credentials import TeamsAuthCredentialsBase -from app.users import current_active_user +from app.users import require_session_context from app.utils.connector_naming import ( check_duplicate_connector, extract_identifier_from_credentials, @@ -74,7 +74,10 @@ def get_token_encryption() -> TokenEncryption: @router.get("/auth/teams/connector/add") -async def connect_teams(space_id: int, user: User = Depends(current_active_user)): +async def connect_teams( + space_id: int, + auth: AuthContext = Depends(require_session_context), +): """ Initiate Microsoft Teams OAuth flow. @@ -85,6 +88,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user) Returns: Authorization URL for redirect """ + user = auth.user try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") diff --git a/surfsense_backend/app/routes/video_presentations_routes.py b/surfsense_backend/app/routes/video_presentations_routes.py index ed694b9bf..189a050e4 100644 --- a/surfsense_backend/app/routes/video_presentations_routes.py +++ b/surfsense_backend/app/routes/video_presentations_routes.py @@ -16,6 +16,7 @@ from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( Permission, SearchSpace, @@ -25,7 +26,7 @@ from app.db import ( get_async_session, ) from app.schemas import VideoPresentationRead -from app.users import current_active_user +from app.users import get_auth_context from app.utils.rbac import check_permission router = APIRouter() @@ -37,8 +38,9 @@ async def read_video_presentations( limit: int = 100, search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ List video presentations the user has access to. Requires VIDEO_PRESENTATIONS_READ permission for the search space(s). @@ -49,7 +51,7 @@ async def read_video_presentations( if search_space_id is not None: await check_permission( session, - user, + auth, search_space_id, Permission.VIDEO_PRESENTATIONS_READ.value, "You don't have permission to read video presentations in this search space", @@ -89,8 +91,9 @@ async def read_video_presentations( async def read_video_presentation( video_presentation_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Get a specific video presentation by ID. Requires authentication with VIDEO_PRESENTATIONS_READ permission. @@ -112,7 +115,7 @@ async def read_video_presentation( await check_permission( session, - user, + auth, video_pres.search_space_id, Permission.VIDEO_PRESENTATIONS_READ.value, "You don't have permission to read video presentations in this search space", @@ -132,8 +135,9 @@ async def read_video_presentation( async def delete_video_presentation( video_presentation_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Delete a video presentation. Requires VIDEO_PRESENTATIONS_DELETE permission for the search space. @@ -151,7 +155,7 @@ async def delete_video_presentation( await check_permission( session, - user, + auth, db_video_pres.search_space_id, Permission.VIDEO_PRESENTATIONS_DELETE.value, "You don't have permission to delete video presentations in this search space", @@ -175,8 +179,9 @@ async def stream_slide_audio( video_presentation_id: int, slide_number: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), ): + user = auth.user """ Stream the audio file for a specific slide in a video presentation. The slide_number is 1-based. Audio path is read from the slides JSONB. @@ -194,7 +199,7 @@ async def stream_slide_audio( await check_permission( session, - user, + auth, video_pres.search_space_id, Permission.VIDEO_PRESENTATIONS_READ.value, "You don't have permission to access video presentations in this search space", diff --git a/surfsense_backend/app/routes/youtube_routes.py b/surfsense_backend/app/routes/youtube_routes.py index 9fc6d1dfc..c9d958aa8 100644 --- a/surfsense_backend/app/routes/youtube_routes.py +++ b/surfsense_backend/app/routes/youtube_routes.py @@ -8,8 +8,8 @@ import time from fastapi import APIRouter, Depends, HTTPException, Query from scrapling.fetchers import AsyncFetcher -from app.db import User -from app.users import current_active_user +from app.auth.context import AuthContext +from app.users import require_session_context from app.utils.proxy import get_proxy_url router = APIRouter() @@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = { @router.get("/youtube/playlist-videos") async def get_playlist_videos( url: str = Query(..., description="YouTube playlist URL"), - _user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ): """Resolve a YouTube playlist URL into individual video URLs.""" match = _PLAYLIST_ID_RE.search(url) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 7b508a132..1566310e1 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -104,6 +104,7 @@ from .search_source_connector import ( SearchSourceConnectorUpdate, ) from .search_space import ( + SearchSpaceApiAccessUpdate, SearchSpaceBase, SearchSpaceCreate, SearchSpaceRead, @@ -243,6 +244,7 @@ __all__ = [ "SearchSourceConnectorUpdate", # Search space schemas "SearchSpaceBase", + "SearchSpaceApiAccessUpdate", "SearchSpaceCreate", "SearchSpaceRead", "SearchSpaceUpdate", diff --git a/surfsense_backend/app/schemas/pat.py b/surfsense_backend/app/schemas/pat.py new file mode 100644 index 000000000..a4f70e21e --- /dev/null +++ b/surfsense_backend/app/schemas/pat.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + + +class PATCreate(BaseModel): + label: str = Field(min_length=1, max_length=120) + expires_in_days: int | None = Field(default=None, gt=0) + + +class PATCreated(BaseModel): + id: int + label: str + token: str + prefix: str + expires_at: datetime | None = None + + +class PATRead(BaseModel): + id: int + label: str + prefix: str + expires_at: datetime | None = None + last_used_at: datetime | None = None + created_at: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index 70ed0004e..d74c46716 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -24,11 +24,16 @@ class SearchSpaceUpdate(BaseModel): ai_file_sort_enabled: bool | None = None +class SearchSpaceApiAccessUpdate(BaseModel): + api_access_enabled: bool + + class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): id: int created_at: datetime user_id: uuid.UUID citations_enabled: bool + api_access_enabled: bool = False qna_custom_instructions: str | None = None shared_memory_md: str | None = None ai_file_sort_enabled: bool = False diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py index 905482010..b44f6f37c 100644 --- a/surfsense_backend/app/services/chat_comments_service.py +++ b/surfsense_backend/app/services/chat_comments_service.py @@ -9,6 +9,7 @@ from sqlalchemy import delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.db import ( ChatComment, ChatCommentMention, @@ -138,8 +139,9 @@ async def get_comment_thread_participants( async def get_comments_for_message( session: AsyncSession, message_id: int, - user: User, + auth: AuthContext, ) -> CommentListResponse: + user = auth.user """ Get all comments for a message with their replies. @@ -169,7 +171,7 @@ async def get_comments_for_message( # Check permission to read comments await check_permission( session, - user, + auth, search_space_id, Permission.COMMENTS_READ.value, "You don't have permission to read comments in this search space", @@ -268,8 +270,9 @@ async def get_comments_for_message( async def get_comments_for_messages_batch( session: AsyncSession, message_ids: list[int], - user: User, + auth: AuthContext, ) -> CommentBatchResponse: + user = auth.user """ Batch-fetch comments for multiple messages in a single DB round-trip. @@ -295,7 +298,7 @@ async def get_comments_for_messages_batch( for ss_id in search_space_ids: await check_permission( session, - user, + auth, ss_id, Permission.COMMENTS_READ.value, "You don't have permission to read comments in this search space", @@ -409,8 +412,9 @@ async def create_comment( session: AsyncSession, message_id: int, content: str, - user: User, + auth: AuthContext, ) -> CommentResponse: + user = auth.user """ Create a top-level comment on an AI response. @@ -521,8 +525,9 @@ async def create_reply( session: AsyncSession, comment_id: int, content: str, - user: User, + auth: AuthContext, ) -> CommentReplyResponse: + user = auth.user """ Create a reply to an existing comment. @@ -657,8 +662,9 @@ async def update_comment( session: AsyncSession, comment_id: int, content: str, - user: User, + auth: AuthContext, ) -> CommentReplyResponse: + user = auth.user """ Update a comment's content (author only). @@ -797,8 +803,9 @@ async def update_comment( async def delete_comment( session: AsyncSession, comment_id: int, - user: User, + auth: AuthContext, ) -> dict: + user = auth.user """ Delete a comment (author or user with COMMENTS_DELETE permission). @@ -844,9 +851,10 @@ async def delete_comment( async def get_user_mentions( session: AsyncSession, - user: User, + auth: AuthContext, search_space_id: int | None = None, ) -> MentionListResponse: + user = auth.user """ Get mentions for the current user, optionally filtered by search space. diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index d17f411b8..0df69de09 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -21,6 +21,7 @@ from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.db import ( ChatVisibility, NewChatMessage, @@ -163,8 +164,9 @@ def compute_content_hash(messages: list[dict]) -> str: async def create_snapshot( session: AsyncSession, thread_id: int, - user: User, + auth: AuthContext, ) -> dict: + user = auth.user """ Create a public snapshot of a chat thread. @@ -186,7 +188,7 @@ async def create_snapshot( await check_permission( session, - user, + auth, thread.search_space_id, Permission.PUBLIC_SHARING_CREATE.value, "You don't have permission to create public share links", @@ -431,8 +433,9 @@ async def get_public_chat( async def list_snapshots_for_thread( session: AsyncSession, thread_id: int, - user: User, + auth: AuthContext, ) -> list[dict]: + user = auth.user """List all public snapshots for a thread.""" from app.config import config @@ -447,7 +450,7 @@ async def list_snapshots_for_thread( # Check permission to view public share links await check_permission( session, - user, + auth, thread.search_space_id, Permission.PUBLIC_SHARING_VIEW.value, "You don't have permission to view public share links", @@ -477,14 +480,15 @@ async def list_snapshots_for_thread( async def list_snapshots_for_search_space( session: AsyncSession, search_space_id: int, - user: User, + auth: AuthContext, ) -> list[dict]: + user = auth.user """List all public snapshots for a search space.""" from app.config import config await check_permission( session, - user, + auth, search_space_id, Permission.PUBLIC_SHARING_VIEW.value, "You don't have permission to view public share links", @@ -534,8 +538,9 @@ async def delete_snapshot( session: AsyncSession, thread_id: int, snapshot_id: int, - user: User, + auth: AuthContext, ) -> bool: + user = auth.user """Delete a specific snapshot. Only thread owner can delete.""" # Get snapshot with thread result = await session.execute( @@ -553,7 +558,7 @@ async def delete_snapshot( await check_permission( session, - user, + auth, snapshot.thread.search_space_id, Permission.PUBLIC_SHARING_DELETE.value, "You don't have permission to delete public share links", diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py index dcbd37521..9d7d1b0c5 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py @@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemSelection, ) from app.agents.chat.runtime.llm_config import AgentConfig +from app.auth.context import AuthContext from app.db import ChatVisibility from app.services.connector_service import ConnectorService @@ -33,6 +34,7 @@ async def build_main_agent_for_thread( filesystem_selection: FilesystemSelection | None, disabled_tools: list[str] | None = None, mentioned_document_ids: list[int] | None = None, + auth_context: AuthContext | None = None, ) -> Any: return await agent_factory( llm=llm, @@ -48,4 +50,5 @@ async def build_main_agent_for_thread( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py index 1e6097e53..69343ffa4 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py @@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemMode, FilesystemSelection, ) +from app.auth.context import AuthContext from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.new_streaming_service import VercelStreamingService @@ -136,6 +137,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + auth_context: AuthContext | None = None, flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """Stream a new chat turn using the SurfSense deep agent. @@ -412,6 +414,7 @@ async def stream_new_chat( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 @@ -664,6 +667,7 @@ async def stream_new_chat( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) _perf_log.info( "[stream_new_chat] Runtime rate-limit recovery repinned " diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py index e1552e79e..33fcee3da 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemMode, FilesystemSelection, ) +from app.auth.context import AuthContext from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.chat_session_state_service import set_ai_responding @@ -102,6 +103,7 @@ async def stream_resume_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, disabled_tools: list[str] | None = None, + auth_context: AuthContext | None = None, ) -> AsyncGenerator[str, None]: """Resume a paused HITL turn with the user's decisions. @@ -346,6 +348,7 @@ async def stream_resume_chat( thread_visibility=visibility, filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, + auth_context=auth_context, ) _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 @@ -481,6 +484,7 @@ async def stream_resume_chat( thread_visibility=visibility, filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, + auth_context=auth_context, ) _perf_log.info( "[stream_resume] Runtime rate-limit recovery repinned " diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 66e0cc8dd..d668dba45 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -3,7 +3,7 @@ import uuid from datetime import UTC, datetime import httpx -from fastapi import Depends, Request, Response +from fastapi import Depends, HTTPException, Request, Response, status from fastapi.responses import JSONResponse, RedirectResponse from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users.authentication import ( @@ -14,7 +14,9 @@ from fastapi_users.authentication import ( from fastapi_users.db import SQLAlchemyUserDatabase from pydantic import BaseModel from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.config import config from app.db import ( Prompt, @@ -23,10 +25,12 @@ from app.db import ( SearchSpaceRole, User, async_session_maker, + get_async_session, get_default_roles_config, get_user_db, ) from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS +from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) @@ -298,5 +302,75 @@ auth_backend = AuthenticationBackend( fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) -current_active_user = fastapi_users.current_user(active=True) + +async def get_auth_context( + request: Request, + session: AsyncSession = Depends(get_async_session), + user_manager: UserManager = Depends(get_user_manager), +) -> AuthContext: + """Resolve the authenticated principal. + + Use this for authorization-sensitive routes where session-vs-PAT matters. + FastAPI-Users still handles JWT mechanics; PATs are resolved here so RBAC + receives the full SurfSense principal instead of a bare User. + """ + auth_header = request.headers.get("Authorization") + if not auth_header: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Unauthorized", + ) + + scheme, _, token = auth_header.partition(" ") + if scheme.lower() != "bearer" or not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Unauthorized", + ) + + if token.startswith(PAT_PREFIX): + pat = await resolve_pat(session, token) + if pat and pat.user and pat.user.is_active: + maybe_touch_last_used(pat) + return AuthContext.pat_auth(pat.user, pat) + + try: + user = await get_jwt_strategy().read_token(token, user_manager) + except Exception: + logger.exception("Failed to read access token") + user = None + + if not user or not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Unauthorized", + ) + + return AuthContext.session(user) + + +async def allow_any_principal( + auth: AuthContext = Depends(get_auth_context), +) -> AuthContext: + """Allow either session or PAT principals for bootstrap probes only. + + Routes using this dependency intentionally have no search-space gate. + Adding a new call site is a security decision and must be covered by + the fail-closed PAT allowlist test. + """ + return auth + + +async def require_session_context( + auth: AuthContext = Depends(get_auth_context), +) -> AuthContext: + """Require an interactive session and reject PAT-authenticated requests.""" + if not auth.is_session: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="This action requires an interactive session", + ) + return auth + + current_optional_user = fastapi_users.current_user(active=True, optional=True) diff --git a/surfsense_backend/app/utils/pat.py b/surfsense_backend/app/utils/pat.py new file mode 100644 index 000000000..46e3d4d08 --- /dev/null +++ b/surfsense_backend/app/utils/pat.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import asyncio +import hashlib +import logging +import secrets +from datetime import UTC, datetime, timedelta + +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import PersonalAccessToken, User, async_session_maker + +logger = logging.getLogger(__name__) + +PAT_PREFIX = "ss_pat_" +PAT_TOKEN_BYTES = 32 +LAST_USED_THROTTLE = timedelta(minutes=10) + + +def generate_pat() -> str: + return f"{PAT_PREFIX}{secrets.token_urlsafe(PAT_TOKEN_BYTES)}" + + +def hash_pat(token: str) -> str: + return hashlib.sha256(token.encode()).hexdigest() + + +def token_prefix(token: str) -> str: + return token[:16] + + +async def resolve_pat( + session: AsyncSession, + token: str, +) -> PersonalAccessToken | None: + now = datetime.now(UTC) + result = await session.execute( + select(PersonalAccessToken) + .options(selectinload(PersonalAccessToken.user)) + .join(User) + .where( + PersonalAccessToken.token_hash == hash_pat(token), + (PersonalAccessToken.expires_at.is_(None)) + | (PersonalAccessToken.expires_at > now), + User.is_active == True, # noqa: E712 + ) + ) + return result.scalars().first() + + +async def _touch_last_used(token_id: int) -> None: + try: + async with async_session_maker() as session: + await session.execute( + update(PersonalAccessToken) + .where(PersonalAccessToken.id == token_id) + .values(last_used_at=datetime.now(UTC)) + ) + await session.commit() + except Exception: + logger.exception("Failed to update PAT last_used_at for token %s", token_id) + + +def maybe_touch_last_used(pat: PersonalAccessToken) -> None: + last_used_at = pat.last_used_at + now = datetime.now(UTC) + if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE: + return + + asyncio.create_task(_touch_last_used(pat.id)) diff --git a/surfsense_backend/app/utils/rbac.py b/surfsense_backend/app/utils/rbac.py index 6cb180d80..8777f09f6 100644 --- a/surfsense_backend/app/utils/rbac.py +++ b/surfsense_backend/app/utils/rbac.py @@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload +from app.auth.context import AuthContext from app.db import ( Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, - User, has_permission, ) @@ -80,9 +80,33 @@ async def get_user_permissions( return [] +async def _enforce_api_access_gate( + session: AsyncSession, + auth: AuthContext, + search_space_id: int, + search_space: SearchSpace | None = None, +) -> SearchSpace: + if search_space is None: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + if auth.is_gated and not search_space.api_access_enabled: + raise HTTPException( + status_code=403, + detail="API access is not enabled for this search space.", + ) + + return search_space + + async def check_permission( session: AsyncSession, - user: User, + auth: AuthContext, search_space_id: int, required_permission: str, error_message: str = "You don't have permission to perform this action", @@ -104,7 +128,7 @@ async def check_permission( Raises: HTTPException: If user doesn't have access or permission """ - membership = await get_user_membership(session, user.id, search_space_id) + membership = await get_user_membership(session, auth.user.id, search_space_id) if not membership: raise HTTPException( @@ -123,12 +147,14 @@ async def check_permission( if not has_permission(permissions, required_permission): raise HTTPException(status_code=403, detail=error_message) + await _enforce_api_access_gate(session, auth, search_space_id) + return membership async def check_search_space_access( session: AsyncSession, - user: User, + auth: AuthContext, search_space_id: int, ) -> SearchSpaceMembership: """ @@ -146,7 +172,7 @@ async def check_search_space_access( Raises: HTTPException: If user doesn't have access """ - membership = await get_user_membership(session, user.id, search_space_id) + membership = await get_user_membership(session, auth.user.id, search_space_id) if not membership: raise HTTPException( @@ -154,6 +180,8 @@ async def check_search_space_access( detail="You don't have access to this search space", ) + await _enforce_api_access_gate(session, auth, search_space_id) + return membership @@ -179,7 +207,7 @@ async def is_search_space_owner( async def get_search_space_with_access_check( session: AsyncSession, - user: User, + auth: AuthContext, search_space_id: int, required_permission: str | None = None, ) -> tuple[SearchSpace, SearchSpaceMembership]: @@ -210,10 +238,12 @@ async def get_search_space_with_access_check( # Check access if required_permission: membership = await check_permission( - session, user, search_space_id, required_permission + session, auth, search_space_id, required_permission ) else: - membership = await check_search_space_access(session, user, search_space_id) + membership = await check_search_space_access(session, auth, search_space_id) + + await _enforce_api_access_gate(session, auth, search_space_id, search_space) return search_space, membership diff --git a/surfsense_backend/tests/README.md b/surfsense_backend/tests/README.md index 5764252a5..23b161f99 100644 --- a/surfsense_backend/tests/README.md +++ b/surfsense_backend/tests/README.md @@ -42,7 +42,7 @@ Maximize logic covered by unit tests; keep integration tests for what genuinely `conftest.py` is scoped to its directory and below. Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test. -For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically. +For API integration tests, override `get_async_session` and `get_auth_context` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically. ## Import mode diff --git a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py index a5182a978..c6a40c356 100644 --- a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py +++ b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py @@ -40,6 +40,7 @@ import pytest_asyncio from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( ChatVisibility, NewChatMessage, @@ -395,7 +396,7 @@ class TestAppendMessageRecoveryAfterFinalize: thread_id=thread_id, request=request, session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) # Response must echo the SERVER's rich payload, not the FE's @@ -469,7 +470,7 @@ class TestAppendMessageRecoveryAfterFinalize: thread_id=thread_id, request=_FakeRequest(fe_request_body), session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) assert fe_response.role == NewChatMessageRole.ASSISTANT @@ -552,7 +553,7 @@ class TestAppendMessageRecoveryAfterFinalize: } ), session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) assert ok_response.role == NewChatMessageRole.USER assert ok_response.turn_id is None diff --git a/surfsense_backend/tests/integration/chat/test_thread_visibility.py b/surfsense_backend/tests/integration/chat/test_thread_visibility.py index 464d389db..ba6f2a66f 100644 --- a/surfsense_backend/tests/integration/chat/test_thread_visibility.py +++ b/surfsense_backend/tests/integration/chat/test_thread_visibility.py @@ -16,6 +16,7 @@ from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( ChatVisibility, SearchSpace, @@ -33,6 +34,10 @@ from app.schemas.new_chat import ( pytestmark = pytest.mark.integration +def _auth(user: User) -> AuthContext: + return AuthContext.session(user) + + @pytest_asyncio.fixture async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User: member = User( @@ -85,7 +90,7 @@ async def _create_thread( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) @@ -108,13 +113,13 @@ async def test_private_thread_is_hidden_from_other_search_space_member( member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) assert thread.id not in _active_thread_ids(member_threads) @@ -123,7 +128,7 @@ async def test_private_thread_is_hidden_from_other_search_space_member( await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 @@ -142,24 +147,24 @@ async def test_creator_can_share_thread_and_member_can_list_search_read_it( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) full_thread = await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert updated.visibility == ChatVisibility.SEARCH_SPACE @@ -181,20 +186,20 @@ async def test_rename_and_archive_do_not_reset_shared_visibility( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) renamed = await new_chat_routes.update_thread( thread_id=thread.id, thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"), session=db_session, - user=db_user, + auth=_auth(db_user), ) archived = await new_chat_routes.update_thread( thread_id=thread.id, thread_update=NewChatThreadUpdate(archived=True), session=db_session, - user=db_user, + auth=_auth(db_user), ) assert renamed.visibility == ChatVisibility.SEARCH_SPACE @@ -215,7 +220,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) with pytest.raises(HTTPException) as exc_info: @@ -225,7 +230,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 @@ -244,7 +249,7 @@ async def test_creator_can_make_shared_thread_private_again( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) private_again = await new_chat_routes.update_thread_visibility( @@ -253,18 +258,18 @@ async def test_creator_can_make_shared_thread_private_again( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) assert private_again.visibility == ChatVisibility.PRIVATE @@ -274,6 +279,6 @@ async def test_creator_can_make_shared_thread_private_again( await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 diff --git a/surfsense_backend/tests/integration/composio/conftest.py b/surfsense_backend/tests/integration/composio/conftest.py index 44d707ec3..578b5b228 100644 --- a/surfsense_backend/tests/integration/composio/conftest.py +++ b/surfsense_backend/tests/integration/composio/conftest.py @@ -15,6 +15,7 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, @@ -22,7 +23,7 @@ from app.db import ( User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -40,12 +41,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( diff --git a/surfsense_backend/tests/integration/notifications/conftest.py b/surfsense_backend/tests/integration/notifications/conftest.py index 17a44a51d..e410d0d55 100644 --- a/surfsense_backend/tests/integration/notifications/conftest.py +++ b/surfsense_backend/tests/integration/notifications/conftest.py @@ -1,9 +1,9 @@ """Notifications integration fixtures. -The app's DB session and current-user dependencies are overridden to ride the +The app's DB session and auth-context dependencies are overridden to ride the test's transactional `db_session`, so API calls and seeded rows share one -transaction that rolls back per test. Overriding `current_active_user` also -bypasses real JWT auth, so these tests don't depend on AUTH_TYPE. +transaction that rolls back per test. Overriding `get_auth_context` also bypasses +real JWT auth, so these tests don't depend on AUTH_TYPE. """ from __future__ import annotations @@ -17,8 +17,9 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.db import User, get_async_session -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -33,12 +34,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( diff --git a/surfsense_backend/tests/integration/podcasts/conftest.py b/surfsense_backend/tests/integration/podcasts/conftest.py index 75248a6a1..067924ad5 100644 --- a/surfsense_backend/tests/integration/podcasts/conftest.py +++ b/surfsense_backend/tests/integration/podcasts/conftest.py @@ -24,6 +24,7 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.config import config as app_config from app.db import SearchSpace, User, get_async_session from app.podcasts.persistence import Podcast, PodcastStatus @@ -39,7 +40,7 @@ from app.podcasts.schemas import ( from app.podcasts.service import PodcastService from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech from app.routes.search_spaces_routes import create_default_roles_and_membership -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -54,12 +55,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( @@ -290,7 +291,7 @@ def act_as(): """ def _act(user: User) -> None: - app.dependency_overrides[current_active_user] = lambda: user + app.dependency_overrides[get_auth_context] = lambda: AuthContext.session(user) return _act diff --git a/surfsense_backend/tests/integration/test_connector_index_authz.py b/surfsense_backend/tests/integration/test_connector_index_authz.py index cea2407cc..b25df7087 100644 --- a/surfsense_backend/tests/integration/test_connector_index_authz.py +++ b/surfsense_backend/tests/integration/test_connector_index_authz.py @@ -23,6 +23,7 @@ import pytest from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( SearchSourceConnector, SearchSourceConnectorType, @@ -109,7 +110,7 @@ class TestConnectorIndexCrossSpaceAuthz: connector_id=connector_a.id, search_space_id=space_b.id, # the attacker's own space session=db_session, - user=attacker, + auth=AuthContext.session(attacker), ) assert exc_info.value.status_code == 404 @@ -140,7 +141,7 @@ class TestConnectorIndexCrossSpaceAuthz: connector_id=connector.id, search_space_id=space.id, # the connector's own space session=db_session, - user=owner, + auth=AuthContext.session(owner), ) check_permission_mock.assert_awaited_once() diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 22f6c6de5..d56c18420 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -28,6 +28,7 @@ from sqlalchemy import func, select, text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( SearchSourceConnector, SearchSourceConnectorType, @@ -42,6 +43,7 @@ from app.routes.obsidian_plugin_routes import ( obsidian_stats, obsidian_sync, ) +from app.routes.search_spaces_routes import create_default_roles_and_membership from app.schemas.obsidian_plugin import ( ConnectRequest, DeleteAck, @@ -65,6 +67,10 @@ pytestmark = pytest.mark.integration # --------------------------------------------------------------------------- +def _auth(user: User) -> AuthContext: + return AuthContext.session(user) + + def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload: """Minimal NotePayload that the schema accepts; the indexer is mocked out so the values don't have to round-trip through the real pipeline.""" @@ -102,6 +108,8 @@ async def race_user_and_space(async_engine): ) space = SearchSpace(name="Race Space", user_id=user_id) setup.add_all([user, space]) + await setup.flush() + await create_default_roles_and_membership(setup, space.id, user_id) await setup.commit() await setup.refresh(space) space_id = space.id @@ -116,6 +124,14 @@ async def race_user_and_space(async_engine): text("DELETE FROM search_source_connectors WHERE user_id = :uid"), {"uid": user_id}, ) + await cleanup.execute( + text("DELETE FROM search_space_memberships WHERE search_space_id = :id"), + {"id": space_id}, + ) + await cleanup.execute( + text("DELETE FROM search_space_roles WHERE search_space_id = :id"), + {"id": space_id}, + ) await cleanup.execute( text("DELETE FROM searchspaces WHERE id = :id"), {"id": space_id}, @@ -154,7 +170,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ) - await obsidian_connect(payload, user=fresh_user, session=s) + await obsidian_connect(payload, auth=_auth(fresh_user), session=s) results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True) for r in results: @@ -281,7 +297,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ), - user=fresh_user, + auth=_auth(fresh_user), session=s, ) @@ -294,7 +310,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ), - user=fresh_user, + auth=_auth(fresh_user), session=s, ) @@ -337,7 +353,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert connect_resp.connector_id > 0 @@ -361,7 +377,7 @@ class TestWireContractSmoke: _make_note_payload(vault_id, "fail.md", "hash-fail"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -394,7 +410,7 @@ class TestWireContractSmoke: _make_note_payload(vault_id, "fail.md", "h2"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert sync_resp.indexed == 1 @@ -420,7 +436,7 @@ class TestWireContractSmoke: RenameItem(old_path="missing.md", new_path="x.md"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert isinstance(rename_resp, RenameAck) @@ -441,7 +457,7 @@ class TestWireContractSmoke: ): delete_resp = await obsidian_delete_notes( DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert isinstance(delete_resp, DeleteAck) @@ -456,7 +472,7 @@ class TestWireContractSmoke: # upsert_note was mocked) but the response shape is what we care # about. manifest_resp = await obsidian_manifest( - vault_id=vault_id, user=db_user, session=db_session + vault_id=vault_id, auth=_auth(db_user), session=db_session ) assert isinstance(manifest_resp, ManifestResponse) assert manifest_resp.vault_id == vault_id @@ -464,7 +480,7 @@ class TestWireContractSmoke: # 6. /stats — same; row count is 0 because upsert_note was mocked. stats_resp = await obsidian_stats( - vault_id=vault_id, user=db_user, session=db_session + vault_id=vault_id, auth=_auth(db_user), session=db_session ) assert isinstance(stats_resp, StatsResponse) assert stats_resp.vault_id == vault_id @@ -482,7 +498,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -511,7 +527,7 @@ class TestWireContractSmoke: binary_note, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -533,7 +549,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -562,7 +578,7 @@ class TestWireContractSmoke: bad_note, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -587,7 +603,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -616,7 +632,7 @@ class TestWireContractSmoke: mismatched, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) diff --git a/surfsense_backend/tests/integration/test_pat_fail_closed_authz.py b/surfsense_backend/tests/integration/test_pat_fail_closed_authz.py new file mode 100644 index 000000000..5bec3f48a --- /dev/null +++ b/surfsense_backend/tests/integration/test_pat_fail_closed_authz.py @@ -0,0 +1,71 @@ +"""Runtime smoke tests for fail-closed PAT authorization primitives.""" + +from __future__ import annotations + +import pytest +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.context import AuthContext +from app.db import PersonalAccessToken, SearchSpace, User +from app.users import allow_any_principal, require_session_context +from app.utils.rbac import check_search_space_access + +pytestmark = pytest.mark.integration + + +def _pat_auth(user: User) -> AuthContext: + pat = PersonalAccessToken( + user_id=user.id, + user=user, + token_hash="0" * 64, + token_prefix="ss_pat_test", + label="Test PAT", + ) + return AuthContext.pat_auth(user, pat) + + +async def test_pat_is_rejected_by_session_only_dependency(db_user: User): + auth = _pat_auth(db_user) + + with pytest.raises(HTTPException) as exc_info: + await require_session_context(auth=auth) + + assert exc_info.value.status_code == 403 + + +async def test_pat_is_allowed_by_bootstrap_dependency(db_user: User): + auth = _pat_auth(db_user) + + assert await allow_any_principal(auth=auth) is auth + + +async def test_pat_is_rejected_for_api_disabled_space( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + db_search_space.api_access_enabled = False + await db_session.flush() + auth = _pat_auth(db_user) + + with pytest.raises(HTTPException) as exc_info: + await check_search_space_access(db_session, auth, db_search_space.id) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "API access is not enabled for this search space." + + +async def test_pat_is_allowed_for_api_enabled_space( + db_session: AsyncSession, + db_user: User, + db_search_space: SearchSpace, +): + db_search_space.api_access_enabled = True + await db_session.flush() + auth = _pat_auth(db_user) + + membership = await check_search_space_access(db_session, auth, db_search_space.id) + + assert membership.user_id == db_user.id + assert membership.search_space_id == db_search_space.id diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index c97dec6a2..477d927e2 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -15,6 +15,7 @@ import pytest from fastapi import HTTPException import app.automations.services.automation as automation_mod +from app.auth.context import AuthContext from app.automations.schemas.api import AutomationCreate, AutomationUpdate from app.automations.schemas.definition.envelope import ( AutomationDefinition, @@ -45,7 +46,8 @@ class _FakeSession: def _service(search_space: Any) -> AutomationService: return AutomationService( - session=_FakeSession(search_space), user=SimpleNamespace(id="u-1") + session=_FakeSession(search_space), + auth=AuthContext.session(SimpleNamespace(id="u-1")), ) diff --git a/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py index de4386abb..c184af601 100644 --- a/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py +++ b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py @@ -38,7 +38,9 @@ async def cleanup_supervisors(): @pytest.mark.asyncio async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook") + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") await byo_long_poll.start_byo_long_poll_supervisors() @@ -47,9 +49,11 @@ async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch): @pytest.mark.asyncio async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") session = mocker.AsyncMock() session.execute.return_value = ScalarResult([]) monkeypatch.setattr( @@ -67,9 +71,11 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc async def test_start_byo_long_poll_spawns_one_supervisor_per_account( mocker, monkeypatch ): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") accounts = [mocker.Mock(id=1), mocker.Mock(id=2)] session = mocker.AsyncMock() session.execute.return_value = ScalarResult(accounts) @@ -115,9 +121,11 @@ async def test_supervisor_retries_after_run_returns(mocker, monkeypatch): @pytest.mark.asyncio async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") session = mocker.AsyncMock() session.execute.return_value = ScalarResult([mocker.Mock(id=1)]) monkeypatch.setattr( diff --git a/surfsense_backend/tests/unit/gateway/test_inbox_worker.py b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py index 1e5b2a184..0ee661102 100644 --- a/surfsense_backend/tests/unit/gateway/test_inbox_worker.py +++ b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py @@ -27,6 +27,7 @@ async def test_inbox_worker_claims_and_processes_in_fastapi_process( async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch): started = asyncio.Event() stopped = asyncio.Event() + monkeypatch.setattr(inbox_worker.config, "GATEWAY_ENABLED", True) async def run_forever(): started.set() diff --git a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py index aa8bd3a89..354c3037d 100644 --- a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py +++ b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py @@ -9,6 +9,7 @@ from types import SimpleNamespace import pytest +from app.auth.context import AuthContext from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform from app.routes import gateway_webhook_routes as routes @@ -333,7 +334,9 @@ async def test_discord_gateway_install_returns_oauth_url(monkeypatch, mocker): response = await routes.install_discord_gateway( search_space_id=123, - user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"), + auth=AuthContext.session( + SimpleNamespace(id="00000000-0000-0000-0000-000000000001") + ), session=mocker.AsyncMock(), ) diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py index 35d409a40..44fcfe042 100644 --- a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch import pytest from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.auth.context import AuthContext from app.routes import agent_revert_route from app.services.revert_service import RevertOutcome @@ -147,7 +148,7 @@ class TestFlagGuard: thread_id=1, chat_turn_id="42:1700000000000", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert getattr(exc.value, "status_code", None) == 503 @@ -167,7 +168,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-empty", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.total == 0 @@ -209,7 +210,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-3", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" @@ -248,7 +249,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-i", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.already_reverted == 1 @@ -275,7 +276,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-rev", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.results[0].status == "skipped" @@ -315,7 +316,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-mix", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "partial" assert response.reverted == 1 @@ -354,7 +355,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-fail", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "partial" assert response.failed == 1 @@ -386,7 +387,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-perm", session=session, - user=_FakeUser(id="not-owner"), + auth=AuthContext.session(_FakeUser(id="not-owner")), ) assert response.status == "partial" assert response.results[0].status == "permission_denied" @@ -449,7 +450,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-mixed-all", session=session, - user=_FakeUser(), # only id=7 has a different user_id + auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id ) assert response.total == len(rows) == 6 @@ -518,7 +519,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-race", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.failed == 0, ( diff --git a/surfsense_backend/tests/unit/test_pat_fail_closed_static.py b/surfsense_backend/tests/unit/test_pat_fail_closed_static.py new file mode 100644 index 000000000..01ecd918f --- /dev/null +++ b/surfsense_backend/tests/unit/test_pat_fail_closed_static.py @@ -0,0 +1,101 @@ +"""Static guards for the fail-closed PAT authorization model.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + +BACKEND_ROOT = Path(__file__).resolve().parents[2] +APP_ROOT = BACKEND_ROOT / "app" + +ALLOW_ANY_EXPECTED = { + "app.py": "auth: AuthContext = Depends(allow_any_principal)", + "routes/obsidian_plugin_routes.py": ( + "_auth: AuthContext = Depends(allow_any_principal)" + ), + "routes/search_spaces_routes.py": ( + "auth: AuthContext = Depends(allow_any_principal)" + ), +} + +CONNECTOR_LISTERS = [ + "routes/slack_add_connector_route.py", + "routes/composio_routes.py", + "routes/google_drive_add_connector_route.py", + "routes/discord_add_connector_route.py", + "routes/dropbox_add_connector_route.py", + "routes/onedrive_add_connector_route.py", +] + + +def _python_files() -> list[Path]: + return [ + path + for path in APP_ROOT.rglob("*.py") + if "__pycache__" not in path.parts + ] + + +def test_current_active_user_is_removed_from_app_tree() -> None: + offenders = [ + str(path.relative_to(BACKEND_ROOT)) + for path in _python_files() + if "current_active_user" in path.read_text() + ] + + assert offenders == [] + + +def test_allow_any_principal_is_only_used_by_bootstrap_allowlist() -> None: + actual: dict[str, int] = {} + for path in _python_files(): + text = path.read_text() + count = text.count("Depends(allow_any_principal)") + if count: + actual[str(path.relative_to(APP_ROOT))] = count + + assert actual == dict.fromkeys(ALLOW_ANY_EXPECTED, 1) + + for rel_path, expected_snippet in ALLOW_ANY_EXPECTED.items(): + text = (APP_ROOT / rel_path).read_text() + assert expected_snippet in text + + +def test_connector_listers_route_pat_through_search_space_gate() -> None: + for rel_path in CONNECTOR_LISTERS: + text = (APP_ROOT / rel_path).read_text() + assert "auth: AuthContext = Depends(get_auth_context)" in text, rel_path + assert ( + "await check_search_space_access(session, auth, connector.search_space_id)" + in text + ), rel_path + + +def test_identity_routes_are_session_only() -> None: + session_only_files = [ + "routes/prompts_routes.py", + "routes/memory_routes.py", + "routes/model_list_routes.py", + "routes/agent_flags_route.py", + "routes/youtube_routes.py", + "routes/incentive_tasks_routes.py", + "notifications/api/api.py", + "routes/chat_comments_routes.py", + "routes/public_chat_routes.py", + ] + + for rel_path in session_only_files: + text = (APP_ROOT / rel_path).read_text() + assert "require_session_context" in text, rel_path + assert "Depends(get_auth_context)" not in text, rel_path + + +def test_model_connection_personal_writes_default_to_session_required() -> None: + text = (APP_ROOT / "routes/model_connections_routes.py").read_text() + + assert "allow_spaceless_pat: bool = False" in text + assert "auth.is_gated and not allow_spaceless_pat" in text + assert "Managing personal model connections requires an interactive session" in text diff --git a/surfsense_browser_extension/routes/pages/ApiKeyForm.tsx b/surfsense_browser_extension/routes/pages/ApiKeyForm.tsx index 537eba3da..d045d8129 100644 --- a/surfsense_browser_extension/routes/pages/ApiKeyForm.tsx +++ b/surfsense_browser_extension/routes/pages/ApiKeyForm.tsx @@ -16,7 +16,7 @@ const ApiKeyForm = () => { const validateForm = () => { if (!apiKey) { - setError("API key is required"); + setError("Personal access token is required"); return false; } setError(""); @@ -39,11 +39,11 @@ const ApiKeyForm = () => { setLoading(false); if (response.ok) { - // Store the API key as the token + // Store the PAT as the bearer token for existing background handlers. await storage.set("token", apiKey); navigation("/"); } else { - setError("Invalid API key. Please check and try again."); + setError("Invalid personal access token. Please check and try again."); } } catch (error) { setLoading(false); @@ -67,15 +67,15 @@ const ApiKeyForm = () => {
- Your API key connects this extension to the SurfSense. + Your personal access token connects this extension to SurfSense.