Merge pull request #1524 from AnishSarkar22/feat/api-key

feat(auth): replace JWT-as-API-key with hashed PATs + per-space API access gate
This commit is contained in:
Rohan Verma 2026-06-23 00:55:34 -07:00 committed by GitHub
commit a0f44b283c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
118 changed files with 2273 additions and 859 deletions

View file

@ -84,6 +84,9 @@ SECRET_KEY=SECRET
# JWT Token Lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
# Personal Access Tokens (PATs). Empty/unset = no maximum; users may create
# never-expiring PATs. When set, PAT creation requires an expiry <= this many days.
# PAT_MAX_EXPIRY_DAYS=
NEXT_FRONTEND_URL=http://localhost:3000

View file

@ -0,0 +1,83 @@
"""Add personal access tokens and search-space API access gate.
Revision ID: 166
Revises: 165
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "166"
down_revision: str | None = "165"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE IF NOT EXISTS personal_access_tokens (
id SERIAL PRIMARY KEY,
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
token_hash VARCHAR(64) NOT NULL,
token_prefix VARCHAR(16) NOT NULL,
label VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE,
last_used_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL
);
"""
)
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ix_personal_access_tokens_token_hash "
"ON personal_access_tokens (token_hash)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_user_id "
"ON personal_access_tokens (user_id)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_id "
"ON personal_access_tokens (id)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_created_at "
"ON personal_access_tokens (created_at)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_expires_at "
"ON personal_access_tokens (expires_at)"
)
bind = op.get_bind()
api_access_column_exists = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_schema = current_schema()
AND table_name = 'searchspaces'
AND column_name = 'api_access_enabled'
)
"""
)
).scalar()
op.execute(
"ALTER TABLE searchspaces ADD COLUMN IF NOT EXISTS "
"api_access_enabled BOOLEAN NOT NULL DEFAULT false"
)
if not api_access_column_exists:
op.execute("UPDATE searchspaces SET api_access_enabled = true")
def downgrade() -> None:
op.execute(
"ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled"
)
op.execute("DROP TABLE IF EXISTS personal_access_tokens")

View file

@ -34,6 +34,7 @@ from app.agents.chat.runtime.llm_config import AgentConfig
from app.agents.chat.runtime.prompt_caching import (
apply_litellm_prompt_caching,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.user_tool_allowlist import (
@ -73,6 +74,7 @@ async def create_multi_agent_chat_deep_agent(
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
image_gen_model_id: int | None = None,
auth_context: AuthContext | None = None,
):
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
@ -139,6 +141,7 @@ async def create_multi_agent_chat_deep_agent(
"connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key,
"user_id": user_id,
"auth_context": auth_context,
"thread_id": thread_id,
"thread_visibility": visibility,
"available_connectors": available_connectors,

View file

@ -30,9 +30,10 @@ from pydantic import ValidationError
from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate
from app.automations.services.automation import AutomationService
from app.db import User, async_session_maker
from app.db import async_session_maker
from app.utils.content_utils import extract_text_content
from .prompt import build_draft_prompt
@ -47,6 +48,7 @@ def create_create_automation_tool(
search_space_id: int,
user_id: str | UUID,
llm: Any,
auth_context: AuthContext | None = None,
):
"""Factory for the ``create_automation`` tool.
@ -56,8 +58,6 @@ def create_create_automation_tool(
``AsyncSession`` is opened per call to avoid stale sessions on
compiled-agent cache hits (same pattern as the Notion / memory tools).
"""
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
"""Draft + save an automation from a natural-language intent.
@ -165,14 +165,17 @@ def create_create_automation_tool(
"issues": _format_validation_issues(exc),
}
if auth_context is None:
logger.error(
"create_automation called without AuthContext; refusing to persist"
)
return {
"status": "error",
"message": "authorization context missing for automation creation",
}
async with async_session_maker() as session:
user = await session.get(User, uid)
if user is None:
return {
"status": "error",
"message": "user not found in this session",
}
service = AutomationService(session=session, user=user)
service = AutomationService(session=session, auth=auth_context)
created = await service.create(final_validated)
return {
"status": "saved",

View file

@ -60,6 +60,7 @@ def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
return create_create_automation_tool(
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
auth_context=deps.get("auth_context"),
llm=deps["llm"],
)

View file

@ -27,6 +27,7 @@ from app.agents.chat.runtime.checkpointer import (
close_checkpointer,
setup_checkpointer_tables,
)
from app.auth.context import AuthContext
from app.config import (
config,
initialize_image_gen_router,
@ -34,7 +35,7 @@ from app.config import (
initialize_openrouter_integration,
initialize_pricing_registration,
)
from app.db import User, create_db_and_tables, get_async_session
from app.db import create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.gateway.byo_long_poll import (
start_byo_long_poll_supervisors,
@ -55,7 +56,7 @@ from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.session_events import register_session_hooks
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users
from app.utils.perf import log_system_snapshot
_error_logger = logging.getLogger("surfsense.errors")
@ -1032,7 +1033,7 @@ async def readiness_check():
@app.get("/verify-token")
async def authenticated_route(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(allow_any_principal),
session: AsyncSession = Depends(get_async_session),
):
return {"message": "Token is valid"}
return {"message": "Token is valid", "method": auth.method}

View file

@ -0,0 +1 @@
"""Authentication principals and helpers."""

View file

@ -0,0 +1,38 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from app.db import PersonalAccessToken, User
AuthMethod = Literal["session", "pat", "system"]
@dataclass(frozen=True)
class AuthContext:
"""Typed principal for authorization decisions."""
user: User
method: AuthMethod
pat: PersonalAccessToken | None = None
source: str | None = None
@classmethod
def session(cls, user: User) -> AuthContext:
return cls(user=user, method="session")
@classmethod
def pat_auth(cls, user: User, pat: PersonalAccessToken) -> AuthContext:
return cls(user=user, method="pat", pat=pat)
@classmethod
def system(cls, user: User, source: str) -> AuthContext:
return cls(user=user, method="system", source=source)
@property
def is_gated(self) -> bool:
return self.method == "pat"
@property
def is_session(self) -> bool:
return self.method == "session"

View file

@ -16,7 +16,8 @@ from app.agents.chat.runtime.mention_resolver import (
substitute_in_text,
)
from app.agents.chat.shared.context import SurfSenseContextSchema
from app.db import ChatVisibility, async_session_maker
from app.auth.context import AuthContext
from app.db import ChatVisibility, User, async_session_maker
from app.schemas.new_chat import MentionedDocumentInfo
from ...types import ActionContext
@ -147,6 +148,12 @@ async def run_agent_task(
decision = "approve" if auto_approve_all else "reject"
async with async_session_maker() as agent_session:
auth_context = None
if ctx.creator_user_id:
user = await agent_session.get(User, ctx.creator_user_id)
if user is not None:
auth_context = AuthContext.system(user, source="automation")
deps = await build_dependencies(
session=agent_session,
search_space_id=ctx.search_space_id,
@ -168,6 +175,7 @@ async def run_agent_task(
thread_visibility=ChatVisibility.PRIVATE,
mentioned_document_ids=mentioned_document_ids,
image_gen_model_id=ctx.image_gen_model_id,
auth_context=auth_context,
)
agent_query, runtime_context = await _resolve_mention_context(

View file

@ -27,17 +27,19 @@ from app.automations.services.model_policy import (
)
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, SearchSpace, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, SearchSpace, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class AutomationService:
"""Lifecycle of the ``Automation`` resource."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
self.user = auth.user
async def create(self, payload: AutomationCreate) -> Automation:
"""Create an automation and its initial triggers in one transaction."""
@ -235,7 +237,7 @@ class AutomationService:
async def _authorize(self, search_space_id: int, permission: str) -> None:
await check_permission(
self.session,
self.user,
self.auth,
search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -274,6 +276,6 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
def get_automation_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> AutomationService:
return AutomationService(session=session, user=user)
return AutomationService(session=session, auth=auth)

View file

@ -8,17 +8,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class RunService:
"""Read-only access to ``AutomationRun`` history."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
async def list(
self,
@ -63,7 +64,7 @@ class RunService:
)
await check_permission(
self.session,
self.user,
self.auth,
automation.search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -73,6 +74,6 @@ class RunService:
def get_run_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> RunService:
return RunService(session=session, user=user)
return RunService(session=session, auth=auth)

View file

@ -14,17 +14,18 @@ from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class TriggerService:
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
async def add(
self, *, automation_id: int, payload: TriggerCreate
@ -101,7 +102,7 @@ class TriggerService:
)
await check_permission(
self.session,
self.user,
self.auth,
automation.search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -144,6 +145,6 @@ def _initial_next_fire(
def get_trigger_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> TriggerService:
return TriggerService(session=session, user=user)
return TriggerService(session=session, auth=auth)

View file

@ -919,6 +919,10 @@ class Config:
REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
)
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
PAT_MAX_EXPIRY_DAYS = (
int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
)
# ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE")

View file

@ -368,6 +368,9 @@ class Permission(StrEnum):
SETTINGS_UPDATE = "settings:update"
SETTINGS_DELETE = "settings:delete" # Delete the entire search space
# API Access
API_ACCESS_MANAGE = "api_access:manage"
# Public Sharing
PUBLIC_SHARING_VIEW = "public_sharing:view"
PUBLIC_SHARING_CREATE = "public_sharing:create"
@ -1693,6 +1696,9 @@ class SearchSpace(BaseModel, TimestampMixin):
citations_enabled = Column(
Boolean, nullable=False, default=True
) # Enable/disable citations
api_access_enabled = Column(
Boolean, nullable=False, default=False, server_default="false"
)
qna_custom_instructions = Column(
Text, nullable=True, default=""
) # User's custom instructions
@ -2330,6 +2336,11 @@ if config.AUTH_TYPE == "GOOGLE":
back_populates="user",
cascade="all, delete-orphan",
)
personal_access_tokens = relationship(
"PersonalAccessToken",
back_populates="user",
cascade="all, delete-orphan",
)
else:
@ -2462,6 +2473,11 @@ else:
back_populates="user",
cascade="all, delete-orphan",
)
personal_access_tokens = relationship(
"PersonalAccessToken",
back_populates="user",
cascade="all, delete-orphan",
)
class AgentActionLog(BaseModel):
@ -2712,6 +2728,36 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked
class PersonalAccessToken(BaseModel, TimestampMixin):
"""
Stores hashed Personal Access Tokens for programmatic API access.
Plaintext tokens are shown once on creation and are never persisted.
"""
__tablename__ = "personal_access_tokens"
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user = relationship("User", back_populates="personal_access_tokens")
token_hash = Column(String(64), unique=True, nullable=False, index=True)
token_prefix = Column(String(16), nullable=False)
label = Column(String, nullable=False)
expires_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True)
last_used_at = Column(TIMESTAMP(timezone=True), nullable=True)
@property
def is_expired(self) -> bool:
return self.expires_at is not None and datetime.now(UTC) >= self.expires_at
@property
def is_valid(self) -> bool:
return not self.is_expired
# Register model packages that live outside this file so their classes
# are present in Base.metadata before configure_mappers() resolves any
# string-based relationship() references.

View file

@ -9,7 +9,8 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, Permission, User, get_async_session
from app.auth.context import AuthContext
from app.db import Document, Permission, get_async_session
from app.file_storage.persistence.enums import DocumentFileKind
from app.file_storage.schemas import DocumentFileRead
from app.file_storage.service import (
@ -17,14 +18,14 @@ from app.file_storage.service import (
list_document_files,
open_document_file_stream,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
router = APIRouter()
async def _load_readable_document(
*, document_id: int, session: AsyncSession, user: User
*, document_id: int, session: AsyncSession, auth: AuthContext
) -> Document:
"""Load a document the user may read, or raise 404/403."""
document = (
@ -35,7 +36,7 @@ async def _load_readable_document(
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -57,10 +58,10 @@ def _content_disposition(filename: str) -> str:
async def read_document_files(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> list[DocumentFileRead]:
"""Return metadata for every stored file of a document (gates the UI)."""
await _load_readable_document(document_id=document_id, session=session, user=user)
await _load_readable_document(document_id=document_id, session=session, auth=auth)
records = await list_document_files(session, document_id=document_id)
return [DocumentFileRead.model_validate(r) for r in records]
@ -69,10 +70,10 @@ async def read_document_files(
async def download_original_document_file(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> StreamingResponse:
"""Stream the document's original uploaded file."""
await _load_readable_document(document_id=document_id, session=session, user=user)
await _load_readable_document(document_id=document_id, session=session, auth=auth)
record = await get_document_file(
session, document_id=document_id, kind=DocumentFileKind.ORIGINAL

View file

@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ExternalChatBinding, NewChatMessage
from app.gateway.auth_invariant import assert_authorization_invariant
from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent
@ -64,6 +65,7 @@ async def call_agent_for_gateway(
request_id: str | None = None,
) -> None:
user = await assert_authorization_invariant(session, binding)
auth_context = AuthContext.system(user, source="gateway")
thread = await get_or_create_thread_for_binding(session, binding)
await session.commit()
@ -81,6 +83,7 @@ async def call_agent_for_gateway(
current_user_display_name=user.display_name or "A team member",
disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
request_id=request_id or "gateway",
auth_context=auth_context,
)
events = _events_from_sse(stream)
try:

View file

@ -5,6 +5,7 @@ from __future__ import annotations
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ExternalChatBinding, Permission, User
from app.gateway.bindings import suspend_binding
from app.observability.metrics import record_gateway_auth_invariant_failure
@ -39,11 +40,13 @@ async def assert_authorization_invariant(
if user is None:
await _fail(session, binding, "owner_missing")
auth = AuthContext.system(user, source="gateway")
try:
await check_search_space_access(session, user, binding.search_space_id)
await check_search_space_access(session, auth, binding.search_space_id)
await check_permission(
session,
user,
auth,
binding.search_space_id,
Permission.CHATS_CREATE.value,
"External chat owner no longer has permission to chat in this search space",

View file

@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import case, desc, func, literal, literal_column, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.auth.context import AuthContext
from app.db import get_async_session
from app.notifications.api.schemas import (
BatchUnreadCountResponse,
CategoryUnreadCount,
@ -27,7 +28,7 @@ from app.notifications.api.transform import (
from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS
from app.notifications.persistence import Notification
from app.notifications.types import NotificationCategory, NotificationType
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter(prefix="/notifications", tags=["notifications"])
@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse)
async def get_unread_counts_batch(
search_space_id: int | None = Query(None, description="Filter by search space ID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> BatchUnreadCountResponse:
"""Unread counts for every category in a single query."""
user = auth.user
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
base_filter = [
@ -86,10 +88,11 @@ async def get_unread_counts_batch(
@router.get("/source-types", response_model=SourceTypesResponse)
async def get_notification_source_types(
search_space_id: int | None = Query(None, description="Filter by search space ID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> SourceTypesResponse:
"""Distinct connector/document source types for the Status tab filter."""
user = auth.user
base_filter = [Notification.user_id == user.id]
if search_space_id is not None:
@ -160,7 +163,7 @@ async def get_unread_count(
category: NotificationCategory | None = Query(
None, description="Filter by category: 'comments' or 'status'"
),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> UnreadCountResponse:
"""Total and recent (within sync window) unread counts for the user.
@ -168,6 +171,7 @@ async def get_unread_count(
Returning both lets a client hold the older count static while
live-syncing the recent ones.
"""
user = auth.user
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
base_filter = [
@ -230,10 +234,11 @@ async def list_notifications(
),
limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
offset: int = Query(0, ge=0, description="Number of items to skip"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> NotificationListResponse:
"""Paginated inbox fallback for items outside the Zero sync window."""
user = auth.user
query = select(Notification).where(Notification.user_id == user.id)
count_query = select(func.count(Notification.id)).where(
Notification.user_id == user.id
@ -328,10 +333,11 @@ async def list_notifications(
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
async def mark_notification_as_read(
notification_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> MarkReadResponse:
"""Mark one of the user's notifications read; Zero syncs the change."""
user = auth.user
# Scope to the caller's own notifications.
result = await session.execute(
select(Notification).where(
@ -364,10 +370,11 @@ async def mark_notification_as_read(
@router.patch("/read-all", response_model=MarkAllReadResponse)
async def mark_all_notifications_as_read(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> MarkAllReadResponse:
"""Mark all of the user's notifications read; Zero syncs the changes."""
user = auth.user
result = await session.execute(
update(Notification)
.where(

View file

@ -18,12 +18,12 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config as app_config
from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.podcasts.generation.brief import propose_brief
@ -42,7 +42,7 @@ from app.podcasts.voices import (
provider_from_service,
render_voice_preview,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
from .schemas import (
@ -63,13 +63,14 @@ async def list_podcasts(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
if search_space_id is not None:
await _require(session, user, search_space_id, Permission.PODCASTS_READ)
await _require(session, auth, search_space_id, Permission.PODCASTS_READ)
query = (
select(Podcast)
.where(Podcast.search_space_id == search_space_id)
@ -132,7 +133,7 @@ async def list_languages():
@router.get("/podcasts/voices/{voice_id}/preview")
async def preview_voice(
voice_id: str,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""A short audio sample of a voice, so users pick by sound."""
if not app_config.TTS_SERVICE:
@ -156,9 +157,9 @@ async def preview_voice(
async def create_podcast(
body: CreatePodcastRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE)
await _require(session, auth, body.search_space_id, Permission.PODCASTS_CREATE)
service = PodcastService(session)
podcast = await service.create(
@ -185,9 +186,9 @@ async def create_podcast(
async def get_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
return PodcastDetail.of(podcast)
@ -196,9 +197,9 @@ async def update_spec(
podcast_id: int,
body: UpdateSpecRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors():
await PodcastService(session).update_spec(
podcast, body.spec, body.expected_version
@ -211,10 +212,10 @@ async def update_spec(
async def approve_brief(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""Approve the brief and start drafting the transcript."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors():
await PodcastService(session).begin_drafting(podcast)
await session.commit()
@ -228,10 +229,10 @@ async def approve_brief(
async def regenerate_transcript(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""Reopen the brief gate for a fresh take; drafting waits for re-approval."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors():
await PodcastService(session).regenerate(podcast)
await session.commit()
@ -242,10 +243,10 @@ async def regenerate_transcript(
async def revert_regeneration(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""Back out of a regeneration and return to the finished episode."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors():
await PodcastService(session).revert_regeneration(podcast)
await session.commit()
@ -256,9 +257,9 @@ async def revert_regeneration(
async def cancel_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors():
await PodcastService(session).cancel(podcast)
await session.commit()
@ -269,9 +270,9 @@ async def cancel_podcast(
async def delete_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_DELETE)
await purge_audio(podcast)
await session.delete(podcast)
await session.commit()
@ -282,9 +283,9 @@ async def delete_podcast(
async def stream_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
if podcast.storage_key:
# Verify first so a missing object is a 404, not a mid-stream crash.
@ -323,13 +324,13 @@ async def stream_podcast(
async def _require(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
permission: Permission,
) -> None:
await check_permission(
session,
user,
auth,
search_space_id,
permission.value,
"You don't have permission for podcasts in this search space",
@ -338,14 +339,14 @@ async def _require(
async def _load(
session: AsyncSession,
user: User,
auth: AuthContext,
podcast_id: int,
permission: Permission,
) -> Podcast:
podcast = await PodcastRepository(session).get(podcast_id)
if podcast is None:
raise HTTPException(status_code=404, detail="Podcast not found")
await _require(session, user, podcast.search_space_id, permission)
await _require(session, auth, podcast.search_space_id, permission)
return podcast

View file

@ -54,6 +54,7 @@ from .notes_routes import router as notes_router
from .notion_add_connector_route import router as notion_add_connector_router
from .obsidian_plugin_routes import router as obsidian_plugin_router
from .onedrive_add_connector_route import router as onedrive_add_connector_router
from .personal_access_tokens_routes import router as personal_access_tokens_router
from .prompts_routes import router as prompts_router
from .public_chat_routes import router as public_chat_router
from .rbac_routes import router as rbac_router
@ -113,6 +114,7 @@ router.include_router(slack_add_connector_router)
router.include_router(teams_add_connector_router)
router.include_router(onedrive_add_connector_router)
router.include_router(obsidian_plugin_router) # Obsidian plugin push API
router.include_router(personal_access_tokens_router) # Personal access token manager
router.include_router(discord_add_connector_router)
router.include_router(jira_add_connector_router)
router.include_router(confluence_add_connector_router)

View file

@ -28,6 +28,7 @@ from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.db import (
AgentActionLog,
@ -36,7 +37,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
@ -111,8 +112,9 @@ async def list_thread_actions(
page: int = Query(0, ge=0),
page_size: int = Query(50, ge=1, le=200),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> AgentActionListResponse:
user = auth.user
"""List agent actions for a thread, newest first.
Authorization:
@ -132,7 +134,7 @@ async def list_thread_actions(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to view this thread's action log.",

View file

@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import (
AgentFeatureFlags,
get_flags,
)
from app.auth.context import AuthContext
from app.config import config
from app.db import User
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter()
@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel):
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
async def get_agent_flags(
_user: User = Depends(current_active_user),
_auth: AuthContext = Depends(require_session_context),
) -> AgentFeatureFlagsRead:
return AgentFeatureFlagsRead.from_flags(get_flags())

View file

@ -30,6 +30,7 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.db import (
AgentPermissionRule,
@ -39,7 +40,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
@ -133,15 +134,16 @@ def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead:
async def _ensure_search_space_membership_admin(
session: AsyncSession, user: User, search_space_id: int
session: AsyncSession, auth: AuthContext, search_space_id: int
) -> None:
user = auth.user
"""Curating agent rules == "settings" administration on the space."""
space = await session.get(SearchSpace, search_space_id)
if space is None:
raise HTTPException(status_code=404, detail="Search space not found.")
await check_permission(
session,
user,
auth,
search_space_id,
Permission.SETTINGS_UPDATE.value,
"You don't have permission to manage agent permission rules in this space.",
@ -160,8 +162,9 @@ async def _ensure_search_space_membership_admin(
async def list_rules(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> list[AgentPermissionRuleRead]:
user = auth.user
_flag_guard()
await _ensure_search_space_membership_admin(session, user, search_space_id)
@ -183,8 +186,9 @@ async def create_rule(
search_space_id: int,
payload: AgentPermissionRuleCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> AgentPermissionRuleRead:
user = auth.user
_flag_guard()
await _ensure_search_space_membership_admin(session, user, search_space_id)
@ -232,8 +236,9 @@ async def update_rule(
rule_id: int,
payload: AgentPermissionRuleUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> AgentPermissionRuleRead:
user = auth.user
_flag_guard()
await _ensure_search_space_membership_admin(session, user, search_space_id)
@ -266,8 +271,9 @@ async def delete_rule(
search_space_id: int,
rule_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> None:
user = auth.user
_flag_guard()
await _ensure_search_space_membership_admin(session, user, search_space_id)

View file

@ -5,7 +5,7 @@ here" affordance. To prevent accidental usage during the gap we return
``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE``
flag flips. Once enabled, the route runs:
1. Authentication via :func:`current_active_user`.
1. Authentication via an interactive session context.
2. Action lookup; 404 if the action does not belong to the thread.
3. Authorization via :func:`app.services.revert_service.can_revert`.
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
@ -33,9 +33,9 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.auth.context import AuthContext
from app.db import (
AgentActionLog,
User,
get_async_session,
)
from app.services.revert_service import (
@ -45,7 +45,7 @@ from app.services.revert_service import (
load_thread,
revert_action,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -57,8 +57,9 @@ async def revert_agent_action(
thread_id: int,
action_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> dict:
user = auth.user
flags = get_flags()
if flags.disable_new_agent_stack or not flags.enable_revert_route:
raise HTTPException(
@ -269,7 +270,7 @@ async def revert_agent_turn(
thread_id: int,
chat_turn_id: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> RevertTurnResponse:
"""Revert every reversible action emitted during ``chat_turn_id``.
@ -281,6 +282,7 @@ async def revert_agent_turn(
Partial success is intentional and returned with HTTP 200. Callers
must inspect ``results[*].status`` to find rows that need attention.
"""
user = auth.user
flags = get_flags()
if flags.disable_new_agent_stack or not flags.enable_revert_route:

View file

@ -10,16 +10,16 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.connectors.airtable_connector import fetch_airtable_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -78,7 +78,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/airtable/connector/add")
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
async def connect_airtable(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Airtable OAuth flow.
@ -89,6 +92,7 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -5,6 +5,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from app.auth.context import AuthContext
from app.db import User, async_session_maker
from app.schemas.auth import (
LogoutAllResponse,
@ -13,7 +14,7 @@ from app.schemas.auth import (
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.users import current_active_user, get_jwt_strategy
from app.users import get_jwt_strategy, require_session_context
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
@ -83,11 +84,14 @@ async def revoke_token(request: LogoutRequest):
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(user: User = Depends(current_active_user)):
async def logout_all_devices(
auth: AuthContext = Depends(require_session_context),
):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
user = auth.user
await revoke_all_user_tokens(user.id)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -5,7 +5,8 @@ Routes for chat comments and mentions.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.auth.context import AuthContext
from app.db import get_async_session
from app.schemas.chat_comments import (
CommentBatchRequest,
CommentBatchResponse,
@ -25,7 +26,7 @@ from app.services.chat_comments_service import (
get_user_mentions,
update_comment,
)
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter()
@ -34,20 +35,20 @@ router = APIRouter()
async def batch_list_comments(
request: CommentBatchRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""Batch-fetch comments for multiple messages in one request."""
return await get_comments_for_messages_batch(session, request.message_ids, user)
return await get_comments_for_messages_batch(session, request.message_ids, auth)
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments(
message_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""List all comments for a message with their replies."""
return await get_comments_for_message(session, message_id, user)
return await get_comments_for_message(session, message_id, auth)
@router.post("/messages/{message_id}/comments", response_model=CommentResponse)
@ -55,10 +56,10 @@ async def add_comment(
message_id: int,
request: CommentCreateRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""Create a top-level comment on an AI response."""
return await create_comment(session, message_id, request.content, user)
return await create_comment(session, message_id, request.content, auth)
@router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse)
@ -66,10 +67,10 @@ async def add_reply(
comment_id: int,
request: CommentCreateRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""Reply to an existing comment."""
return await create_reply(session, comment_id, request.content, user)
return await create_reply(session, comment_id, request.content, auth)
@router.put("/comments/{comment_id}", response_model=CommentReplyResponse)
@ -77,20 +78,20 @@ async def edit_comment(
comment_id: int,
request: CommentUpdateRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""Update a comment's content (author only)."""
return await update_comment(session, comment_id, request.content, user)
return await update_comment(session, comment_id, request.content, auth)
@router.delete("/comments/{comment_id}")
async def remove_comment(
comment_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""Delete a comment (author or user with COMMENTS_DELETE permission)."""
return await delete_comment(session, comment_id, user)
return await delete_comment(session, comment_id, auth)
# =============================================================================
@ -102,7 +103,7 @@ async def remove_comment(
async def list_mentions(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""List mentions for the current user."""
return await get_user_mentions(session, user, search_space_id)
return await get_user_mentions(session, auth, search_space_id)

View file

@ -16,15 +16,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.clickup_auth_credentials import ClickUpAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
@ -61,7 +61,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/clickup/connector/add")
async def connect_clickup(space_id: int, user: User = Depends(current_active_user)):
async def connect_clickup(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate ClickUp OAuth flow.
@ -72,6 +75,7 @@ async def connect_clickup(space_id: int, user: User = Depends(current_active_use
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -22,11 +22,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.services.composio_service import (
@ -35,12 +35,13 @@ from app.services.composio_service import (
TOOLKIT_TO_CONNECTOR_TYPE,
ComposioService,
)
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
count_connectors_of_type,
get_base_name_for_type,
)
from app.utils.oauth_security import OAuthStateManager
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -68,13 +69,16 @@ def get_state_manager() -> OAuthStateManager:
@router.get("/composio/toolkits")
async def list_composio_toolkits(user: User = Depends(current_active_user)):
async def list_composio_toolkits(
auth: AuthContext = Depends(require_session_context),
):
"""
List available Composio toolkits.
Returns:
JSON with list of available toolkits and their metadata.
"""
del auth
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@ -98,7 +102,7 @@ async def initiate_composio_auth(
toolkit_id: str = Query(
..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')"
),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Composio OAuth flow for a specific toolkit.
@ -110,6 +114,7 @@ async def initiate_composio_auth(
Returns:
JSON with auth_url to redirect user to Composio authorization
"""
user = auth.user
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@ -446,7 +451,7 @@ async def reauth_composio_connector(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -460,6 +465,7 @@ async def reauth_composio_connector(
connector_id: ID of the existing Composio connector to re-authenticate
return_url: Optional frontend path to redirect to after completion
"""
user = auth.user
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503, detail="Composio integration is not enabled."
@ -644,7 +650,7 @@ async def list_composio_drive_folders(
connector_id: int,
parent_id: str | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
List folders AND files in user's Google Drive via Composio.
@ -659,6 +665,7 @@ async def list_composio_drive_folders(
)
connector = None
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -676,6 +683,8 @@ async def list_composio_drive_folders(
detail="Composio Google Drive connector not found or access denied",
)
await check_search_space_access(session, auth, connector.search_space_id)
composio_connected_account_id = connector.config.get(
"composio_connected_account_id"
)

View file

@ -15,15 +15,15 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -77,7 +77,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/confluence/connector/add")
async def connect_confluence(space_id: int, user: User = Depends(current_active_user)):
async def connect_confluence(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Confluence OAuth flow.
@ -88,6 +91,7 @@ async def connect_confluence(space_id: int, user: User = Depends(current_active_
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -421,10 +425,11 @@ async def reauth_confluence(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
user = auth.user
try:
from sqlalchemy.future import select

View file

@ -15,21 +15,22 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -77,7 +78,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/discord/connector/add")
async def connect_discord(space_id: int, user: User = Depends(current_active_user)):
async def connect_discord(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Discord OAuth flow.
@ -88,6 +92,7 @@ async def connect_discord(space_id: int, user: User = Depends(current_active_use
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -610,7 +615,7 @@ def _compute_channel_permissions(
async def get_discord_channels(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Get list of Discord text channels for a connector with permission info.
@ -628,6 +633,7 @@ async def get_discord_channels(
"""
from sqlalchemy import select
user = auth.user
try:
# Get connector and verify ownership
result = await session.execute(
@ -646,6 +652,8 @@ async def get_discord_channels(
detail="Discord connector not found or access denied",
)
await check_search_space_access(session, auth, connector.search_space_id)
# Get credentials and decrypt bot token
credentials = DiscordAuthCredentialsBase.from_dict(connector.config)
token_encryption = get_token_encryption()

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.runtime.path_resolver import virtual_path_to_doc
from app.db import (
Chunk,
@ -35,7 +36,7 @@ from app.schemas import (
PaginatedResponse,
)
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
try:
@ -60,8 +61,9 @@ MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per file
async def create_documents(
request: DocumentsCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create new documents.
Requires DOCUMENTS_CREATE permission.
@ -70,7 +72,7 @@ async def create_documents(
# Check permission
await check_permission(
session,
user,
auth,
request.search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space",
@ -128,9 +130,10 @@ async def create_documents_file_upload(
use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
):
user = auth.user
"""
Upload files as documents with real-time status tracking.
@ -159,7 +162,7 @@ async def create_documents_file_upload(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space",
@ -340,8 +343,9 @@ async def read_documents(
sort_by: str = "created_at",
sort_order: str = "desc",
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List documents the user has access to, with optional filtering and pagination.
Requires DOCUMENTS_READ permission for the search space(s).
@ -369,7 +373,7 @@ async def read_documents(
if search_space_id is not None:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -519,8 +523,9 @@ async def search_documents(
search_space_id: int | None = None,
document_types: str | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Search documents by title substring, optionally filtered by search_space_id and document_types.
Requires DOCUMENTS_READ permission for the search space(s).
@ -549,7 +554,7 @@ async def search_documents(
if search_space_id is not None:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -677,8 +682,9 @@ async def search_document_titles(
page: int = 0,
page_size: int = 20,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Lightweight document title search optimized for mention picker (@mentions).
@ -703,7 +709,7 @@ async def search_document_titles(
# Check permission for the search space
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -781,8 +787,9 @@ async def get_document_by_virtual_path(
search_space_id: int,
virtual_path: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Resolve a knowledge-base document by its agent-facing virtual path.
The agent renders every document under ``/documents/...`` with a
@ -804,7 +811,7 @@ async def get_document_by_virtual_path(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -838,8 +845,9 @@ async def get_documents_status(
search_space_id: int,
document_ids: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Batch status endpoint for documents in a search space.
@ -849,7 +857,7 @@ async def get_documents_status(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -905,8 +913,9 @@ async def get_documents_status(
async def get_document_type_counts(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get counts of documents by type for search spaces the user has access to.
Requires DOCUMENTS_READ permission for the search space(s).
@ -926,7 +935,7 @@ async def get_document_type_counts(
# Check permission for specific search space
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -965,7 +974,7 @@ async def get_document_by_chunk_id(
5, ge=0, description="Number of chunks before/after the cited chunk to include"
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Retrieves a document based on a chunk ID, including a window of chunks around the cited one.
@ -995,7 +1004,7 @@ async def get_document_by_chunk_id(
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -1060,12 +1069,13 @@ async def get_document_by_chunk_id(
async def get_watched_folders(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -1101,8 +1111,9 @@ async def get_document_chunks_paginated(
None, ge=0, description="Direct offset; overrides page * page_size"
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Paginated chunk loading for a document.
Supports both page-based and offset-based access.
@ -1120,7 +1131,7 @@ async def get_document_chunks_paginated(
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -1162,8 +1173,9 @@ async def get_document_chunks_paginated(
async def read_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific document by ID.
Requires DOCUMENTS_READ permission for the search space.
@ -1182,7 +1194,7 @@ async def read_document(
# Check permission for the search space
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -1216,8 +1228,9 @@ async def update_document(
document_id: int,
document_update: DocumentUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a document.
Requires DOCUMENTS_UPDATE permission for the search space.
@ -1236,7 +1249,7 @@ async def update_document(
# Check permission for the search space
await check_permission(
session,
user,
auth,
db_document.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update documents in this search space",
@ -1275,8 +1288,9 @@ async def update_document(
async def delete_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a document.
Requires DOCUMENTS_DELETE permission for the search space.
@ -1311,7 +1325,7 @@ async def delete_document(
# Check permission for the search space
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space",
@ -1355,8 +1369,9 @@ async def delete_document(
async def list_document_versions(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""List all versions for a document, ordered by version_number descending."""
document = (
await session.execute(select(Document).where(Document.id == document_id))
@ -1396,8 +1411,9 @@ async def get_document_version(
document_id: int,
version_number: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get full version content including source_markdown."""
document = (
await session.execute(select(Document).where(Document.id == document_id))
@ -1434,8 +1450,9 @@ async def restore_document_version(
document_id: int,
version_number: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Restore a previous version: snapshot current state, then overwrite document content."""
document = (
await session.execute(select(Document).where(Document.id == document_id))
@ -1517,8 +1534,9 @@ class FolderSyncFinalizeRequest(PydanticBaseModel):
async def folder_mtime_check(
request: FolderMtimeCheckRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Pre-upload optimization: check which files need uploading based on mtime.
Returns the subset of relative paths where the file is new or has a
@ -1528,7 +1546,7 @@ async def folder_mtime_check(
await check_permission(
session,
user,
auth,
request.search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space",
@ -1587,8 +1605,9 @@ async def folder_upload(
use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Upload files from the desktop app for folder indexing.
Files are written to temp storage and dispatched to a Celery task.
@ -1603,7 +1622,7 @@ async def folder_upload(
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space",
@ -1733,8 +1752,9 @@ async def folder_upload(
async def folder_unlink(
request: FolderUnlinkRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Handle file deletion events from the desktop watcher.
For each relative path, find the matching document and delete it.
@ -1746,7 +1766,7 @@ async def folder_unlink(
await check_permission(
session,
user,
auth,
request.search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space",
@ -1787,8 +1807,9 @@ async def folder_unlink(
async def folder_sync_finalize(
request: FolderSyncFinalizeRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Finalize a full folder scan by deleting orphaned documents.
The client sends the complete list of relative paths currently in the
@ -1803,7 +1824,7 @@ async def folder_sync_finalize(
await check_permission(
session,
user,
auth,
request.search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space",

View file

@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.dropbox import DropboxClient, list_folder_contents
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
router = APIRouter()
@ -66,8 +67,12 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/dropbox/connector/add")
async def connect_dropbox(space_id: int, user: User = Depends(current_active_user)):
async def connect_dropbox(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""Initiate Dropbox OAuth flow."""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -109,10 +114,11 @@ async def reauth_dropbox(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Re-authenticate an existing Dropbox connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -405,10 +411,11 @@ async def list_dropbox_folders(
connector_id: int,
parent_path: str = "",
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""List folders and files in user's Dropbox."""
connector = None
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -424,6 +431,8 @@ async def list_dropbox_folders(
status_code=404, detail="Dropbox connector not found or access denied"
)
await check_search_space_access(session, auth, connector.search_space_id)
dropbox_client = DropboxClient(session, connector_id)
items, error = await list_folder_contents(dropbox_client, path=parent_path)

View file

@ -18,6 +18,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import (
_FILE_EXTENSIONS,
@ -31,7 +32,7 @@ from app.templates.export_helpers import (
get_reference_docx_path,
get_typst_template_path,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ async def get_editor_content(
search_space_id: int,
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get document content for editing.
@ -60,7 +62,7 @@ async def get_editor_content(
# Check RBAC permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -178,15 +180,16 @@ async def download_document_markdown(
search_space_id: int,
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Download the full document content as a .md file.
Reconstructs markdown from source_markdown or chunks.
"""
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
@ -244,8 +247,9 @@ async def save_document(
document_id: int,
data: dict[str, Any],
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Save document markdown and trigger reindexing.
Called when user clicks 'Save & Exit'.
@ -259,7 +263,7 @@ async def save_document(
# Check RBAC permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update documents in this search space",
@ -331,12 +335,13 @@ async def export_document(
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Export a document in the requested format (reuses the report export pipeline)."""
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",

View file

@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Permission, User, get_async_session
from app.services.export_service import build_export_zip
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
@ -24,12 +25,13 @@ async def export_knowledge_base(
None, description="Export only this folder's subtree"
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Export documents as a ZIP of markdown files preserving folder structure."""
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to export documents in this search space",

View file

@ -5,6 +5,7 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import Document, Folder, Permission, User, get_async_session
from app.schemas import (
BulkDocumentMove,
@ -23,7 +24,7 @@ from app.services.folder_service import (
get_subtree_max_depth,
validate_folder_depth,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
router = APIRouter()
@ -33,13 +34,14 @@ router = APIRouter()
async def create_folder(
request: FolderCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Create a new folder. Requires DOCUMENTS_CREATE permission."""
try:
await check_permission(
session,
user,
auth,
request.search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create folders in this search space",
@ -91,13 +93,14 @@ async def create_folder(
async def list_folders(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
@ -122,8 +125,9 @@ async def list_folders(
async def get_folder(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get a single folder. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
@ -132,7 +136,7 @@ async def get_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
@ -152,8 +156,9 @@ async def get_folder(
async def get_folder_breadcrumb(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
@ -162,7 +167,7 @@ async def get_folder_breadcrumb(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space",
@ -196,8 +201,9 @@ async def get_folder_breadcrumb(
async def stop_watching_folder(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Clear the watched flag from a folder's metadata."""
folder = await session.get(Folder, folder_id)
if not folder:
@ -205,7 +211,7 @@ async def stop_watching_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update folders in this search space",
@ -224,8 +230,9 @@ async def update_folder(
folder_id: int,
request: FolderUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Rename a folder. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
@ -234,7 +241,7 @@ async def update_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update folders in this search space",
@ -264,8 +271,9 @@ async def move_folder(
folder_id: int,
request: FolderMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
@ -274,7 +282,7 @@ async def move_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move folders in this search space",
@ -324,8 +332,9 @@ async def reorder_folder(
folder_id: int,
request: FolderReorder,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
try:
folder = await session.get(Folder, folder_id)
@ -334,7 +343,7 @@ async def reorder_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to reorder folders in this search space",
@ -365,8 +374,9 @@ async def reorder_folder(
async def delete_folder(
folder_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
try:
folder = await session.get(Folder, folder_id)
@ -375,7 +385,7 @@ async def delete_folder(
await check_permission(
session,
user,
auth,
folder.search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete folders in this search space",
@ -439,8 +449,9 @@ async def move_document(
document_id: int,
request: DocumentMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
result = await session.execute(
@ -452,7 +463,7 @@ async def move_document(
await check_permission(
session,
user,
auth,
document.search_space_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space",
@ -485,8 +496,9 @@ async def move_document(
async def bulk_move_documents(
request: BulkDocumentMove,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
if not request.document_ids:
@ -504,7 +516,7 @@ async def bulk_move_documents(
for ss_id in search_space_ids:
await check_permission(
session,
user,
auth,
ss_id,
Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space",

View file

@ -20,6 +20,7 @@ from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import JSONResponse, RedirectResponse, Response
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ExternalChatAccount,
@ -51,7 +52,7 @@ from app.observability.metrics import (
record_gateway_inbox_write,
record_gateway_webhook_parse_error,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access
@ -250,14 +251,15 @@ def _telegram_message(payload: dict[str, Any]) -> dict[str, Any] | None:
@router.get("/slack/install")
async def install_slack_gateway(
search_space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, str]:
user = auth.user
if not _slack_gateway_enabled():
raise HTTPException(
status_code=500, detail="Slack gateway OAuth is not configured"
)
await check_search_space_access(session, user, search_space_id)
await check_search_space_access(session, auth, search_space_id)
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
auth_params = {
"client_id": config.GATEWAY_SLACK_CLIENT_ID,
@ -409,14 +411,15 @@ async def slack_gateway_callback(
@router.get("/discord/install")
async def install_discord_gateway(
search_space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, str]:
user = auth.user
if not _discord_gateway_enabled():
raise HTTPException(
status_code=500, detail="Discord gateway OAuth is not configured"
)
await check_search_space_access(session, user, search_space_id)
await check_search_space_access(session, auth, search_space_id)
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
auth_params = {
"client_id": config.DISCORD_CLIENT_ID,
@ -712,10 +715,11 @@ async def telegram_webhook(
@router.post("/bindings/start", response_model=StartBindingResponse)
async def start_binding(
body: StartBindingRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> StartBindingResponse:
await check_search_space_access(session, user, body.search_space_id)
user = auth.user
await check_search_space_access(session, auth, body.search_space_id)
code = generate_pairing_code()
if body.platform == ExternalChatPlatform.TELEGRAM:
if not _telegram_gateway_enabled():
@ -774,9 +778,10 @@ async def start_binding(
@router.get("/bindings")
async def list_bindings(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]:
user = auth.user
result = await session.execute(
select(ExternalChatBinding, ExternalChatAccount)
.join(
@ -803,9 +808,10 @@ async def list_bindings(
@router.get("/connections")
async def list_connections(
platform: ExternalChatPlatform | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]:
user = auth.user
active_whatsapp_mode = _active_whatsapp_account_mode()
if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is None:
return []
@ -946,9 +952,10 @@ async def list_connections(
@router.get("/platforms")
async def list_platforms(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]:
user = auth.user
result = await session.execute(
select(ExternalChatAccount).where(
(ExternalChatAccount.owner_user_id == user.id)
@ -970,8 +977,9 @@ async def list_platforms(
@config_router.get("/config")
async def get_gateway_config(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, bool | str]:
user = auth.user
if not config.GATEWAY_ENABLED:
return {
"enabled": False,
@ -993,9 +1001,10 @@ async def get_gateway_config(
async def update_binding_search_space(
binding_id: int,
body: UpdateBindingSearchSpaceRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found")
@ -1010,7 +1019,7 @@ async def update_binding_search_space(
if account is None or _is_inactive_whatsapp_account(account):
raise HTTPException(status_code=404, detail="Binding not found")
await check_search_space_access(session, user, body.search_space_id)
await check_search_space_access(session, auth, body.search_space_id)
if binding.search_space_id != body.search_space_id:
binding.search_space_id = body.search_space_id
binding.new_chat_thread_id = None
@ -1023,9 +1032,10 @@ async def update_binding_search_space(
async def update_gateway_account_search_space(
account_id: int,
body: UpdateAccountSearchSpaceRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]:
user = auth.user
account = await session.get(ExternalChatAccount, account_id)
if (
account is None
@ -1036,7 +1046,7 @@ async def update_gateway_account_search_space(
):
raise HTTPException(status_code=404, detail="Gateway account not found")
await check_search_space_access(session, user, body.search_space_id)
await check_search_space_access(session, auth, body.search_space_id)
account.owner_search_space_id = body.search_space_id
account.updated_at = datetime.now(UTC)
@ -1061,9 +1071,10 @@ async def update_gateway_account_search_space(
@router.delete("/bindings/{binding_id}")
async def delete_binding(
binding_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found")
@ -1078,9 +1089,10 @@ async def delete_binding(
@router.delete("/accounts/{account_id}")
async def delete_gateway_account(
account_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]:
user = auth.user
account = await session.get(ExternalChatAccount, account_id)
if (
account is None
@ -1114,9 +1126,10 @@ async def delete_gateway_account(
@router.post("/bindings/{binding_id}/resume")
async def resume_external_chat_binding(
binding_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found")

View file

@ -10,6 +10,7 @@ from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ExternalChatAccount,
@ -20,7 +21,7 @@ from app.db import (
get_async_session,
)
from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
router = APIRouter(prefix="/gateway/whatsapp/baileys", tags=["gateway"])
@ -60,11 +61,12 @@ async def _get_user_whatsapp_account(
@router.post("/pair")
async def request_pairing_code(
body: BaileysPairRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> dict[str, Any]:
user = auth.user
_ensure_baileys_enabled()
await check_search_space_access(session, user, body.search_space_id)
await check_search_space_access(session, auth, body.search_space_id)
adapter = WhatsAppBaileysAdapter()
try:
pairing = await adapter.request_pairing_code(phone_number=body.phone_number)
@ -97,8 +99,9 @@ async def request_pairing_code(
@router.get("/health")
async def bridge_health(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, Any]:
user = auth.user
_ensure_baileys_enabled()
adapter = WhatsAppBaileysAdapter()
try:

View file

@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -88,7 +88,11 @@ def get_google_flow():
@router.get("/auth/google/calendar/connector/add")
async def connect_calendar(space_id: int, user: User = Depends(current_active_user)):
async def connect_calendar(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -127,10 +131,11 @@ async def reauth_calendar(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Google Calendar re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -23,6 +23,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_drive import (
GoogleDriveClient,
@ -33,10 +34,9 @@ from app.connectors.google_gmail_connector import fetch_google_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -46,6 +46,7 @@ from app.utils.oauth_security import (
TokenEncryption,
generate_code_verifier,
)
from app.utils.rbac import check_search_space_access
# Relax token scope validation for Google OAuth
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
@ -110,7 +111,10 @@ def get_google_flow():
@router.get("/auth/google/drive/connector/add")
async def connect_drive(space_id: int, user: User = Depends(current_active_user)):
async def connect_drive(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Google Drive OAuth flow.
@ -120,6 +124,7 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -165,7 +170,7 @@ async def reauth_drive(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -178,6 +183,7 @@ async def reauth_drive(
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -472,7 +478,7 @@ async def list_google_drive_folders(
connector_id: int,
parent_id: str | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
List folders AND files in user's Google Drive with hierarchical support.
@ -492,6 +498,7 @@ async def list_google_drive_folders(
]
}
"""
user = auth.user
try:
# Get connector and verify ownership
result = await session.execute(
@ -510,6 +517,8 @@ async def list_google_drive_folders(
detail="Google Drive connector not found or access denied",
)
await check_search_space_access(session, auth, connector.search_space_id)
# Initialize Drive client (credentials will be loaded on first API call)
drive_client = GoogleDriveClient(session, connector_id)

View file

@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -92,7 +92,10 @@ def get_google_flow():
@router.get("/auth/google/gmail/connector/add")
async def connect_gmail(space_id: int, user: User = Depends(current_active_user)):
async def connect_gmail(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Google Gmail OAuth flow.
@ -102,6 +105,7 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -145,10 +149,11 @@ async def reauth_gmail(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Gmail re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -16,6 +16,7 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ImageGeneration,
@ -46,7 +47,7 @@ from app.services.image_gen_router_service import (
)
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -245,8 +246,9 @@ async def _execute_image_generation(
async def create_image_generation(
data: ImageGenerationCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Create and execute an image generation request.
Premium configs are gated by the user's shared premium credit pool.
@ -270,7 +272,7 @@ async def create_image_generation(
try:
await check_permission(
session,
user,
auth,
data.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to create image generations in this search space",
@ -365,8 +367,9 @@ async def list_image_generations(
skip: int = 0,
limit: int = 50,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""List image generations."""
if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
@ -377,7 +380,7 @@ async def list_image_generations(
if search_space_id is not None:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to read image generations in this search space",
@ -417,8 +420,9 @@ async def list_image_generations(
async def get_image_generation(
image_gen_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get a specific image generation by ID."""
try:
result = await session.execute(
@ -430,7 +434,7 @@ async def get_image_generation(
await check_permission(
session,
user,
auth,
image_gen.search_space_id,
Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to read image generations in this search space",
@ -449,8 +453,9 @@ async def get_image_generation(
async def delete_image_generation(
image_gen_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Delete an image generation record."""
try:
result = await session.execute(
@ -462,7 +467,7 @@ async def delete_image_generation(
await check_permission(
session,
user,
auth,
db_image_gen.search_space_id,
Permission.IMAGE_GENERATIONS_DELETE.value,
"You don't have permission to delete image generations in this search space",

View file

@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
INCENTIVE_TASKS_CONFIG,
IncentiveTaskType,
User,
UserIncentiveTask,
get_async_session,
)
@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import (
IncentiveTasksResponse,
TaskAlreadyCompletedResponse,
)
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"])
@router.get("", response_model=IncentiveTasksResponse)
async def get_incentive_tasks(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> IncentiveTasksResponse:
"""
Get all available incentive tasks with the user's completion status.
"""
user = auth.user
# Get all completed tasks for this user
result = await session.execute(
select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id)
@ -75,7 +76,7 @@ async def get_incentive_tasks(
)
async def complete_task(
task_type: IncentiveTaskType,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
) -> CompleteTaskResponse | TaskAlreadyCompletedResponse:
"""
@ -84,6 +85,7 @@ async def complete_task(
Each task can only be completed once. If the task was already completed,
returns the existing completion information without awarding additional credit.
"""
user = auth.user
# Validate task type exists in config
task_config = INCENTIVE_TASKS_CONFIG.get(task_type)
if not task_config:

View file

@ -16,15 +16,15 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -75,7 +75,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/jira/connector/add")
async def connect_jira(space_id: int, user: User = Depends(current_active_user)):
async def connect_jira(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Jira OAuth flow.
@ -86,6 +89,7 @@ async def connect_jira(space_id: int, user: User = Depends(current_active_user))
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -438,10 +442,11 @@ async def reauth_jira(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
user = auth.user
try:
from sqlalchemy.future import select

View file

@ -17,16 +17,16 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.linear_connector import fetch_linear_organization_name
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -79,7 +79,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/linear/connector/add")
async def connect_linear(space_id: int, user: User = Depends(current_active_user)):
async def connect_linear(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Linear OAuth flow.
@ -90,6 +93,7 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -134,10 +138,11 @@ async def reauth_linear(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Linear re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -5,6 +5,7 @@ from sqlalchemy import and_, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
Log,
LogLevel,
@ -16,7 +17,7 @@ from app.db import (
get_async_session,
)
from app.schemas import LogCreate, LogRead, LogUpdate
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
router = APIRouter()
@ -26,8 +27,9 @@ router = APIRouter()
async def create_log(
log: LogCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new log entry.
Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated).
@ -36,7 +38,7 @@ async def create_log(
# Check if the user has access to the search space
await check_permission(
session,
user,
auth,
log.search_space_id,
Permission.LOGS_READ.value,
"You don't have permission to access logs in this search space",
@ -67,8 +69,9 @@ async def read_logs(
start_date: datetime | None = None,
end_date: datetime | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get logs with optional filtering.
Requires LOGS_READ permission for the search space(s).
@ -81,7 +84,7 @@ async def read_logs(
# Check permission for specific search space
await check_permission(
session,
user,
auth,
search_space_id,
Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space",
@ -136,8 +139,9 @@ async def read_logs(
async def read_log(
log_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific log by ID.
Requires LOGS_READ permission for the search space.
@ -152,7 +156,7 @@ async def read_log(
# Check permission for the search space
await check_permission(
session,
user,
auth,
log.search_space_id,
Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space",
@ -172,8 +176,9 @@ async def update_log(
log_id: int,
log_update: LogUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a log entry.
Requires LOGS_READ permission (logs are typically updated by system).
@ -188,7 +193,7 @@ async def update_log(
# Check permission for the search space
await check_permission(
session,
user,
auth,
db_log.search_space_id,
Permission.LOGS_READ.value,
"You don't have permission to access logs in this search space",
@ -215,8 +220,9 @@ async def update_log(
async def delete_log(
log_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a log entry.
Requires LOGS_DELETE permission for the search space.
@ -231,7 +237,7 @@ async def delete_log(
# Check permission for the search space
await check_permission(
session,
user,
auth,
db_log.search_space_id,
Permission.LOGS_DELETE.value,
"You don't have permission to delete logs in this search space",
@ -254,8 +260,9 @@ async def get_logs_summary(
search_space_id: int,
hours: int = 24,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a summary of logs for a search space in the last X hours.
Requires LOGS_READ permission for the search space.
@ -264,7 +271,7 @@ async def get_logs_summary(
# Check permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space",

View file

@ -6,13 +6,13 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class AddLumaConnectorRequest(BaseModel):
@router.post("/connectors/luma/add")
async def add_luma_connector(
request: AddLumaConnectorRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -46,6 +46,7 @@ async def add_luma_connector(
Raises:
HTTPException: If connector already exists or validation fails
"""
user = auth.user
try:
# Check if a Luma connector already exists for this search space and user
result = await session.execute(
@ -118,7 +119,7 @@ async def add_luma_connector(
@router.delete("/connectors/luma")
async def delete_luma_connector(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -135,6 +136,7 @@ async def delete_luma_connector(
Raises:
HTTPException: If connector doesn't exist
"""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -173,7 +175,7 @@ async def delete_luma_connector(
@router.get("/connectors/luma/test")
async def test_luma_connector(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -190,6 +192,7 @@ async def test_luma_connector(
Raises:
HTTPException: If connector doesn't exist or test fails
"""
user = auth.user
try:
# Get the Luma connector for this search space and user
result = await session.execute(

View file

@ -20,14 +20,14 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import (
OAuthStateManager,
@ -164,8 +164,9 @@ def _frontend_redirect(
async def connect_mcp_service(
service: str,
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)
@ -523,9 +524,10 @@ async def reauth_mcp_service(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)

View file

@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.auth.context import AuthContext
from app.db import get_async_session
from app.services.memory import (
MemoryRead,
MemoryScope,
@ -15,7 +16,7 @@ from app.services.memory import (
reset_memory,
save_memory,
)
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter()
@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel):
@router.get("/users/me/memory", response_model=MemoryRead)
async def get_user_memory(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
memory_md = await read_memory(
scope=MemoryScope.USER,
target_id=user.id,
@ -40,9 +42,10 @@ async def get_user_memory(
@router.put("/users/me/memory", response_model=MemoryRead)
async def update_user_memory(
body: MemoryUpdate,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
result = await save_memory(
scope=MemoryScope.USER,
target_id=user.id,
@ -56,9 +59,10 @@ async def update_user_memory(
@router.post("/users/me/memory/reset", response_model=MemoryRead)
async def reset_user_memory(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
result = await reset_memory(
scope=MemoryScope.USER,
target_id=user.id,

View file

@ -5,6 +5,7 @@ from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.config import config
from app.db import (
Connection,
@ -14,7 +15,6 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas import (
@ -42,7 +42,7 @@ from app.services.model_connection_service import (
verify_connection,
)
from app.services.provider_registry import REGISTRY
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.rbac import check_permission
router = APIRouter()
@ -257,8 +257,8 @@ async def _default_unset_roles(
@router.get("/model-providers", response_model=list[ModelProviderRead])
async def list_model_providers(user: User = Depends(current_active_user)):
del user
async def list_model_providers(auth: AuthContext = Depends(require_session_context)):
del auth
local_only = {"ollama_chat", "lm_studio"}
return [
ModelProviderRead(
@ -298,14 +298,16 @@ async def _load_connection(session: AsyncSession, connection_id: int) -> Connect
async def _assert_connection_access(
session: AsyncSession,
user: User,
auth: AuthContext,
conn: Connection,
permission: str = Permission.LLM_CONFIGS_CREATE.value,
allow_spaceless_pat: bool = False,
) -> None:
user = auth.user
if conn.search_space_id:
await check_permission(
session,
user,
auth,
conn.search_space_id,
permission,
"You don't have permission to manage model connections in this search space",
@ -315,17 +317,22 @@ async def _assert_connection_access(
raise HTTPException(
status_code=403, detail="Connection does not belong to user"
)
if auth.is_gated and not allow_spaceless_pat:
raise HTTPException(
status_code=403,
detail="Managing personal model connections requires an interactive session",
)
@router.get("/global-llm-config-status")
async def global_llm_config_status(user: User = Depends(current_active_user)):
del user
async def global_llm_config_status(auth: AuthContext = Depends(require_session_context)):
del auth
return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS}
@router.get("/global-model-connections", response_model=list[ConnectionRead])
async def list_global_connections(user: User = Depends(current_active_user)):
del user
async def list_global_connections(auth: AuthContext = Depends(require_session_context)):
del auth
models_by_connection: dict[int, list[dict]] = {}
for model in config.GLOBAL_MODELS:
models_by_connection.setdefault(model["connection_id"], []).append(model)
@ -339,13 +346,14 @@ async def list_global_connections(user: User = Depends(current_active_user)):
async def list_connections(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
stmt = select(Connection).options(selectinload(Connection.models))
if search_space_id is not None:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model connections in this search space",
@ -363,8 +371,9 @@ async def list_connections(
async def create_connection(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
if data.scope == ConnectionScope.GLOBAL:
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
if data.scope == ConnectionScope.SEARCH_SPACE:
@ -372,11 +381,16 @@ async def create_connection(
raise HTTPException(status_code=400, detail="search_space_id is required")
await check_permission(
session,
user,
auth,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
elif auth.is_gated:
raise HTTPException(
status_code=403,
detail="Managing personal model connections requires an interactive session",
)
payload = data.model_dump(exclude={"search_space_id", "models"})
conn = Connection(
@ -411,16 +425,22 @@ async def create_connection(
async def preview_connection_models(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
auth,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
elif auth.is_gated:
raise HTTPException(
status_code=403,
detail="Testing personal model connections requires an interactive session",
)
draft = Connection(
provider=data.provider,
@ -445,16 +465,22 @@ async def preview_connection_models(
async def test_preview_connection_model(
data: ModelTestPreview,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
auth,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
elif auth.is_gated:
raise HTTPException(
status_code=403,
detail="Testing personal model connections requires an interactive session",
)
model_id = data.model_id.strip()
if not model_id:
@ -491,11 +517,11 @@ async def update_connection(
connection_id: int,
data: ConnectionUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = conn.search_space_id
for key, value in data.model_dump(exclude_unset=True).items():
@ -512,11 +538,11 @@ async def update_connection(
async def delete_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
session, auth, conn, Permission.LLM_CONFIGS_DELETE.value
)
search_space_id = conn.search_space_id
await session.delete(conn)
@ -533,11 +559,11 @@ async def delete_connection(
async def verify_model_connection(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
)
result = await verify_connection(conn)
return VerifyConnectionResponse(
@ -551,11 +577,11 @@ async def verify_model_connection(
async def discover_connection_models(
connection_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
)
try:
discovered = await discover_models(conn)
@ -595,11 +621,11 @@ async def add_manual_model(
connection_id: int,
data: ModelCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
)
model_id = data.model_id.strip()
@ -640,11 +666,11 @@ async def bulk_update_models(
connection_id: int,
data: ModelsBulkUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = conn.search_space_id
@ -674,7 +700,7 @@ async def update_model(
model_id: int,
data: ModelUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
result = await session.execute(
select(Model)
@ -685,7 +711,7 @@ async def update_model(
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
)
search_space_id = model.connection.search_space_id
update = data.model_dump(exclude_unset=True)
@ -704,7 +730,7 @@ async def update_model(
async def test_connection_model(
model_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
result = await session.execute(
select(Model)
@ -715,7 +741,7 @@ async def test_connection_model(
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
)
result = await test_model(model.connection, model)
await session.commit()
@ -730,11 +756,11 @@ async def test_connection_model(
async def get_model_roles(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await check_permission(
session,
user,
auth,
search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model roles in this search space",
@ -756,11 +782,11 @@ async def update_model_roles(
search_space_id: int,
data: ModelRolesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await check_permission(
session,
user,
auth,
search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update model roles in this search space",

View file

@ -10,9 +10,9 @@ import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from app.db import User
from app.auth.context import AuthContext
from app.services.model_list_service import get_model_list
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter()
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class ModelListItem(BaseModel):
@router.get("/models", response_model=list[ModelListItem])
async def list_available_models(
user: User = Depends(current_active_user),
_auth: AuthContext = Depends(require_session_context),
):
"""
Return all available models grouped by provider.

View file

@ -36,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ChatComment,
@ -75,7 +76,7 @@ from app.tasks.chat.streaming.flows import (
stream_new_chat,
stream_resume_chat,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
@ -595,8 +596,9 @@ async def list_threads(
search_space_id: int,
limit: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all accessible threads for the current user in a search space.
Returns threads and archived_threads for ThreadListPrimitive.
@ -615,7 +617,7 @@ async def list_threads(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to read chats in this search space",
@ -702,8 +704,9 @@ async def search_threads(
search_space_id: int,
title: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Search accessible threads by title in a search space.
@ -721,7 +724,7 @@ async def search_threads(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to read chats in this search space",
@ -794,8 +797,9 @@ async def search_threads(
async def create_thread(
thread: NewChatThreadCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new chat thread.
@ -807,7 +811,7 @@ async def create_thread(
try:
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_CREATE.value,
"You don't have permission to create chats in this search space",
@ -852,8 +856,9 @@ async def create_thread(
async def get_thread_messages(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a thread with all its messages.
This is used by ThreadHistoryAdapter.load() to restore conversation.
@ -877,7 +882,7 @@ async def get_thread_messages(
# Check permission to read chats in this search space
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to read chats in this search space",
@ -936,8 +941,9 @@ async def get_thread_messages(
async def get_thread_full(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get full thread details with all messages.
@ -964,7 +970,7 @@ async def get_thread_full(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to read chats in this search space",
@ -1005,8 +1011,9 @@ async def update_thread(
thread_id: int,
thread_update: NewChatThreadUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a thread (title, archived status).
Used for renaming and archiving threads.
@ -1027,7 +1034,7 @@ async def update_thread(
await check_permission(
session,
user,
auth,
db_thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
@ -1074,8 +1081,9 @@ async def update_thread(
async def delete_thread(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a thread and all its messages.
@ -1095,7 +1103,7 @@ async def delete_thread(
await check_permission(
session,
user,
auth,
db_thread.search_space_id,
Permission.CHATS_DELETE.value,
"You don't have permission to delete chats in this search space",
@ -1146,8 +1154,9 @@ async def update_thread_visibility(
thread_id: int,
visibility_update: NewChatThreadVisibilityUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update the visibility/sharing settings of a thread.
@ -1168,7 +1177,7 @@ async def update_thread_visibility(
await check_permission(
session,
user,
auth,
db_thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
@ -1217,7 +1226,7 @@ async def update_thread_visibility(
async def create_thread_snapshot(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Create a public snapshot of the thread.
@ -1229,7 +1238,7 @@ async def create_thread_snapshot(
return await create_snapshot(
session=session,
thread_id=thread_id,
user=user,
auth=auth,
)
@ -1239,7 +1248,7 @@ async def create_thread_snapshot(
async def list_thread_snapshots(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
List all public snapshots for this thread.
@ -1252,7 +1261,7 @@ async def list_thread_snapshots(
snapshots=await list_snapshots_for_thread(
session=session,
thread_id=thread_id,
user=user,
auth=auth,
)
)
@ -1262,7 +1271,7 @@ async def delete_thread_snapshot(
thread_id: int,
snapshot_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Delete a specific snapshot.
@ -1275,7 +1284,7 @@ async def delete_thread_snapshot(
session=session,
thread_id=thread_id,
snapshot_id=snapshot_id,
user=user,
auth=auth,
)
return {"message": "Snapshot deleted successfully"}
@ -1290,8 +1299,9 @@ async def append_message(
thread_id: int,
request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
.. deprecated:: 2026-05
Replaced by the **SSE-based message ID handshake**. The streaming
@ -1321,8 +1331,8 @@ async def append_message(
Requires CHATS_UPDATE permission.
"""
try:
# Capture ``user.id`` as a primitive UUID up front. The
# ``current_active_user`` dependency hands us a ``User`` ORM
# Capture ``user.id`` as a primitive UUID up front. The auth
# dependency hands us a ``User`` ORM
# row bound to ``session``; if the outer ``except
# IntegrityError`` block below ever fires (an unexpected
# constraint like a foreign key violation — the common
@ -1370,7 +1380,7 @@ async def append_message(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
@ -1597,8 +1607,9 @@ async def list_messages(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List messages in a thread with pagination.
@ -1620,7 +1631,7 @@ async def list_messages(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to read chats in this search space",
@ -1662,7 +1673,7 @@ async def list_messages(
@router.get("/agent/tools", response_model=list[AgentToolInfo])
async def list_agent_tools(
_user: User = Depends(current_active_user),
_auth: AuthContext = Depends(get_auth_context),
):
"""Return the list of built-in agent tools with their metadata.
@ -1691,8 +1702,9 @@ async def handle_new_chat(
request: NewChatRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Stream chat responses from the deep agent.
@ -1717,7 +1729,7 @@ async def handle_new_chat(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_CREATE.value,
"You don't have permission to chat in this search space",
@ -1795,6 +1807,7 @@ async def handle_new_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=image_urls,
auth_context=auth,
),
media_type="text/event-stream",
headers={
@ -1821,8 +1834,9 @@ async def cancel_active_turn(
thread_id: int,
response: Response,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Signal cancellation for the currently running turn on ``thread_id``."""
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
@ -1833,7 +1847,7 @@ async def cancel_active_turn(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
@ -1873,8 +1887,9 @@ async def cancel_active_turn(
async def get_turn_status(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
@ -1884,7 +1899,7 @@ async def get_turn_status(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to view chats in this search space",
@ -1911,8 +1926,9 @@ async def regenerate_response(
request: RegenerateRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Regenerate the AI response for a chat thread.
@ -1947,7 +1963,7 @@ async def regenerate_response(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
@ -2288,6 +2304,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
auth_context=auth,
flow="regenerate",
):
yield chunk
@ -2356,8 +2373,9 @@ async def resume_chat(
request: ResumeRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
try:
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
@ -2369,7 +2387,7 @@ async def resume_chat(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_CREATE.value,
"You don't have permission to chat in this search space",
@ -2413,6 +2431,7 @@ async def resume_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
disabled_tools=request.disabled_tools,
auth_context=auth,
),
media_type="text/event-stream",
headers={

View file

@ -9,9 +9,10 @@ from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Document, DocumentType, Permission, User, get_async_session
from app.schemas import DocumentRead, PaginatedResponse
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
router = APIRouter()
@ -27,8 +28,9 @@ async def create_note(
search_space_id: int,
request: CreateNoteRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new note document.
@ -37,7 +39,7 @@ async def create_note(
# Check RBAC permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create notes in this search space",
@ -98,8 +100,9 @@ async def list_notes(
page: int | None = None,
page_size: int = 50,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all notes in a search space.
@ -108,7 +111,7 @@ async def list_notes(
# Check RBAC permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read notes in this search space",
@ -191,8 +194,9 @@ async def delete_note(
search_space_id: int,
note_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a note.
@ -201,7 +205,7 @@ async def delete_note(
# Check RBAC permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete notes in this search space",

View file

@ -17,15 +17,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -76,7 +76,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/notion/connector/add")
async def connect_notion(space_id: int, user: User = Depends(current_active_user)):
async def connect_notion(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Notion OAuth flow.
@ -87,6 +90,7 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -131,10 +135,11 @@ async def reauth_notion(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Notion re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -24,14 +24,14 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -361,8 +361,9 @@ class OAuthConnectorRoute:
@router.get(f"{oauth.auth_prefix}/connector/add")
async def connect(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -406,9 +407,10 @@ class OAuthConnectorRoute:
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,

View file

@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
Document,
DocumentType,
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
upsert_note,
)
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
from app.users import current_active_user
from app.users import allow_any_principal, get_auth_context
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
async def _resolve_vault_connector(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
vault_id: str,
) -> SearchSourceConnector:
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
user = auth.user
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
stmt = select(SearchSourceConnector).where(
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
connector = (await session.execute(stmt)).scalars().first()
if connector is not None:
await check_search_space_access(session, auth, connector.search_space_id)
return connector
raise HTTPException(
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
async def _ensure_search_space_access(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
search_space_id: int,
) -> SearchSpace:
"""Owner-only access to the search space (shared spaces are a follow-up)."""
user = auth.user
result = await session.execute(
select(SearchSpace).where(
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
"message": "You don't own that search space.",
},
)
await check_search_space_access(session, auth, search_space_id)
return space
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
@router.get("/health", response_model=HealthResponse)
async def obsidian_health(
user: User = Depends(current_active_user),
_auth: AuthContext = Depends(allow_any_principal),
) -> HealthResponse:
"""Return the API contract handshake; plugin caches it per onload."""
return HealthResponse(
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
@router.post("/connect", response_model=ConnectResponse)
async def obsidian_connect(
payload: ConnectRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ConnectResponse:
"""Register a vault, refresh an existing one, or adopt another device's row.
@ -321,8 +327,9 @@ async def obsidian_connect(
the partial unique index can never produce two live rows for one vault.
"""
await _ensure_search_space_access(
session, user=user, search_space_id=payload.search_space_id
session, auth=auth, search_space_id=payload.search_space_id
)
user = auth.user
now_iso = datetime.now(UTC).isoformat()
cfg = _build_config(payload, now_iso=now_iso)
@ -445,13 +452,14 @@ async def obsidian_connect(
@router.post("/sync", response_model=SyncAck)
async def obsidian_sync(
payload: SyncBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> SyncAck:
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
user = auth.user
notification = None
try:
notification = await _start_obsidian_sync_notification(
@ -551,12 +559,12 @@ async def obsidian_sync(
@router.post("/rename", response_model=RenameAck)
async def obsidian_rename(
payload: RenameBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> RenameAck:
"""Apply a batch of vault rename events."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
items: list[RenameAckItem] = []
@ -618,12 +626,12 @@ async def obsidian_rename(
@router.delete("/notes", response_model=DeleteAck)
async def obsidian_delete_notes(
payload: DeleteBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> DeleteAck:
"""Soft-delete a batch of notes by vault-relative path."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
deleted = 0
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
@router.get("/manifest", response_model=ManifestResponse)
async def obsidian_manifest(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ManifestResponse:
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
return await get_manifest(session, connector=connector, vault_id=vault_id)
@router.get("/stats", response_model=StatsResponse)
async def obsidian_stats(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> StatsResponse:
"""Active-note count + last sync time for the web tile.
@ -681,7 +689,7 @@ async def obsidian_stats(
``files_synced`` excludes tombstones so it matches ``/manifest``;
``last_sync_at`` includes them so deletes advance the freshness signal.
"""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
is_active = Document.document_metadata["deleted_at"].as_string().is_(None)

View file

@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.onedrive import OneDriveClient, list_folder_contents
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
router = APIRouter()
@ -73,8 +74,12 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/onedrive/connector/add")
async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)):
async def connect_onedrive(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""Initiate OneDrive OAuth flow."""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -119,10 +124,11 @@ async def reauth_onedrive(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Re-authenticate an existing OneDrive connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -412,10 +418,11 @@ async def list_onedrive_folders(
connector_id: int,
parent_id: str | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""List folders and files in user's OneDrive."""
connector = None
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -431,6 +438,8 @@ async def list_onedrive_folders(
status_code=404, detail="OneDrive connector not found or access denied"
)
await check_search_space_access(session, auth, connector.search_space_id)
onedrive_client = OneDriveClient(session, connector_id)
items, error = await list_folder_contents(onedrive_client, parent_id=parent_id)

View file

@ -0,0 +1,104 @@
from datetime import UTC, datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.db import PersonalAccessToken, get_async_session
from app.schemas.pat import PATCreate, PATCreated, PATRead
from app.users import require_session_context
from app.utils.pat import generate_pat, hash_pat, token_prefix
router = APIRouter()
def _expires_at(expires_in_days: int | None) -> datetime | None:
max_expiry_days = config.PAT_MAX_EXPIRY_DAYS
if max_expiry_days is not None:
if expires_in_days is None:
raise HTTPException(
status_code=400,
detail=(
"This deployment requires PATs to have an expiry of "
f"{max_expiry_days} days or less"
),
)
if expires_in_days > max_expiry_days:
raise HTTPException(
status_code=400,
detail=f"PAT expiry cannot exceed {max_expiry_days} days",
)
if expires_in_days is None:
return None
return datetime.now(UTC) + timedelta(days=expires_in_days)
@router.post("/pats", response_model=PATCreated)
async def create_personal_access_token(
body: PATCreate,
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(require_session_context),
) -> PATCreated:
token = generate_pat()
pat = PersonalAccessToken(
user_id=auth.user.id,
token_hash=hash_pat(token),
token_prefix=token_prefix(token),
label=body.label.strip(),
expires_at=_expires_at(body.expires_in_days),
)
session.add(pat)
await session.commit()
await session.refresh(pat)
return PATCreated(
id=pat.id,
label=pat.label,
token=token,
prefix=pat.token_prefix,
expires_at=pat.expires_at,
)
@router.get("/pats", response_model=list[PATRead])
async def list_personal_access_tokens(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(require_session_context),
) -> list[PATRead]:
result = await session.execute(
select(PersonalAccessToken)
.where(PersonalAccessToken.user_id == auth.user.id)
.order_by(PersonalAccessToken.created_at.desc())
)
return [
PATRead(
id=pat.id,
label=pat.label,
prefix=pat.token_prefix,
expires_at=pat.expires_at,
last_used_at=pat.last_used_at,
created_at=pat.created_at,
)
for pat in result.scalars().all()
]
@router.delete("/pats/{pat_id}", status_code=204)
async def delete_personal_access_token(
pat_id: int,
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(require_session_context),
) -> None:
await session.execute(
delete(PersonalAccessToken).where(
PersonalAccessToken.id == pat_id,
PersonalAccessToken.user_id == auth.user.id,
)
)
await session.commit()

View file

@ -3,14 +3,15 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Prompt, SearchSpaceMembership, User, get_async_session
from app.auth.context import AuthContext
from app.db import Prompt, SearchSpaceMembership, get_async_session
from app.schemas.prompts import (
PromptCreate,
PromptRead,
PromptUpdate,
PublicPromptRead,
)
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter(tags=["Prompts"])
@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"])
async def list_prompts(
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
query = select(Prompt).where(Prompt.user_id == user.id)
if search_space_id is not None:
query = query.where(Prompt.search_space_id == search_space_id)
@ -33,8 +35,9 @@ async def list_prompts(
async def create_prompt(
body: PromptCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
if body.search_space_id is not None:
membership = await session.execute(
select(SearchSpaceMembership).where(
@ -67,8 +70,9 @@ async def update_prompt(
prompt_id: int,
body: PromptUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
result = await session.execute(
select(Prompt).where(
Prompt.id == prompt_id,
@ -99,8 +103,9 @@ async def update_prompt(
async def delete_prompt(
prompt_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
result = await session.execute(
select(Prompt).where(
Prompt.id == prompt_id,
@ -119,8 +124,9 @@ async def delete_prompt(
@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
result = await session.execute(
select(Prompt)
.options(selectinload(Prompt.user))
@ -141,8 +147,9 @@ async def list_public_prompts(
async def copy_public_prompt(
prompt_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
result = await session.execute(
select(Prompt).where(
Prompt.id == prompt_id,

View file

@ -11,7 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.auth.context import AuthContext
from app.db import get_async_session
from app.schemas.new_chat import (
CloneResponse,
PublicChatResponse,
@ -23,7 +24,7 @@ from app.services.public_chat_service import (
get_snapshot_report,
get_snapshot_video_presentation,
)
from app.users import current_active_user
from app.users import require_session_context
router = APIRouter(prefix="/public", tags=["public"])
@ -46,8 +47,9 @@ async def read_public_chat(
async def clone_public_chat(
share_token: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
"""
Clone a public chat snapshot to the user's account.

View file

@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.db import (
Permission,
SearchSpace,
@ -43,7 +44,7 @@ from app.schemas import (
RoleUpdate,
UserSearchSpaceAccess,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import (
check_permission,
check_search_space_access,
@ -107,6 +108,8 @@ PERMISSION_DESCRIPTIONS = {
"settings:view": "View search space settings",
"settings:update": "Modify search space settings",
"settings:delete": "Delete the entire search space",
# API access
"api_access:manage": "Enable or disable programmatic API access for a search space",
# Automations
"automations:create": "Create automations from chat or JSON",
"automations:read": "View automations, their triggers, and run history",
@ -120,8 +123,9 @@ PERMISSION_DESCRIPTIONS = {
@router.get("/permissions", response_model=PermissionsListResponse)
async def list_all_permissions(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all available permissions that can be assigned to roles.
"""
@ -156,8 +160,9 @@ async def create_role(
search_space_id: int,
role_data: RoleCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new custom role in a search space.
Requires ROLES_CREATE permission.
@ -165,7 +170,7 @@ async def create_role(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.ROLES_CREATE.value,
"You don't have permission to create roles",
@ -237,8 +242,9 @@ async def create_role(
async def list_roles(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all roles in a search space.
Requires ROLES_READ permission.
@ -246,7 +252,7 @@ async def list_roles(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.ROLES_READ.value,
"You don't have permission to view roles",
@ -275,8 +281,9 @@ async def get_role(
search_space_id: int,
role_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific role by ID.
Requires ROLES_READ permission.
@ -284,7 +291,7 @@ async def get_role(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.ROLES_READ.value,
"You don't have permission to view roles",
@ -320,8 +327,9 @@ async def update_role(
role_id: int,
role_update: RoleUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a role.
Requires ROLES_UPDATE permission.
@ -330,7 +338,7 @@ async def update_role(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.ROLES_UPDATE.value,
"You don't have permission to update roles",
@ -417,8 +425,9 @@ async def delete_role(
search_space_id: int,
role_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a custom role.
Requires ROLES_DELETE permission.
@ -427,7 +436,7 @@ async def delete_role(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.ROLES_DELETE.value,
"You don't have permission to delete roles",
@ -474,8 +483,9 @@ async def delete_role(
async def list_members(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all members of a search space.
Requires MEMBERS_VIEW permission.
@ -483,7 +493,7 @@ async def list_members(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_VIEW.value,
"You don't have permission to view members",
@ -539,8 +549,9 @@ async def update_member_role(
membership_id: int,
membership_update: MembershipUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a member's role.
Requires MEMBERS_MANAGE_ROLES permission.
@ -549,7 +560,7 @@ async def update_member_role(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_MANAGE_ROLES.value,
"You don't have permission to manage member roles",
@ -629,8 +640,9 @@ async def update_member_role(
async def leave_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Leave a search space (remove own membership).
Owners cannot leave their search space.
@ -675,8 +687,9 @@ async def remove_member(
search_space_id: int,
membership_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Remove a member from a search space.
Requires MEMBERS_REMOVE permission.
@ -685,7 +698,7 @@ async def remove_member(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_REMOVE.value,
"You don't have permission to remove members",
@ -733,8 +746,9 @@ async def create_invite(
search_space_id: int,
invite_data: InviteCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new invite link for a search space.
Requires MEMBERS_INVITE permission.
@ -742,7 +756,7 @@ async def create_invite(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_INVITE.value,
"You don't have permission to create invites",
@ -798,8 +812,9 @@ async def create_invite(
async def list_invites(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all invites for a search space.
Requires MEMBERS_INVITE permission.
@ -807,7 +822,7 @@ async def list_invites(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_INVITE.value,
"You don't have permission to view invites",
@ -837,8 +852,9 @@ async def update_invite(
invite_id: int,
invite_update: InviteUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update an invite.
Requires MEMBERS_INVITE permission.
@ -846,7 +862,7 @@ async def update_invite(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_INVITE.value,
"You don't have permission to update invites",
@ -903,8 +919,9 @@ async def revoke_invite(
search_space_id: int,
invite_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Revoke (delete) an invite.
Requires MEMBERS_INVITE permission.
@ -912,7 +929,7 @@ async def revoke_invite(
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.MEMBERS_INVITE.value,
"You don't have permission to revoke invites",
@ -1022,8 +1039,9 @@ async def get_invite_info(
async def accept_invite(
request: InviteAcceptRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Accept an invite and join a search space.
"""
@ -1120,13 +1138,14 @@ async def accept_invite(
async def get_my_access(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get the current user's access info for a search space.
"""
try:
membership = await check_search_space_access(session, user, search_space_id)
membership = await check_search_space_access(session, auth, search_space_id)
# Get search space name
result = await session.execute(

View file

@ -28,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
Report,
SearchSpace,
@ -42,7 +43,7 @@ from app.templates.export_helpers import (
get_reference_docx_path,
get_typst_template_path,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -158,8 +159,9 @@ def _normalize_latex_delimiters(text: str) -> str:
async def _get_report_with_access(
report_id: int,
session: AsyncSession,
user: User,
auth: AuthContext,
) -> Report:
user = auth.user
"""Fetch a report and verify the user belongs to its search space.
Raises HTTPException(404) if not found, HTTPException(403) if no access.
@ -172,7 +174,7 @@ async def _get_report_with_access(
# Lightweight membership check - no granular RBAC, just "is the user a
# member of the search space this report belongs to?"
await check_search_space_access(session, user, report.search_space_id)
await check_search_space_access(session, auth, report.search_space_id)
return report
@ -206,8 +208,9 @@ async def read_reports(
limit: int = Query(default=100, ge=1, le=MAX_REPORT_LIST_LIMIT),
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List reports the user has access to.
Filters by search space membership.
@ -215,7 +218,7 @@ async def read_reports(
try:
if search_space_id is not None:
# Verify the caller is a member of the requested search space
await check_search_space_access(session, user, search_space_id)
await check_search_space_access(session, auth, search_space_id)
result = await session.execute(
select(Report)
@ -247,8 +250,9 @@ async def read_reports(
async def read_report(
report_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific report by ID (metadata only, no content).
"""
@ -266,8 +270,9 @@ async def read_report(
async def read_report_content(
report_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get full Markdown content of a report, including version siblings.
"""
@ -298,8 +303,9 @@ async def update_report_content(
report_id: int,
body: ReportContentUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update the Markdown content of a report.
@ -339,8 +345,9 @@ async def update_report_content(
async def preview_report_pdf(
report_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Return a compiled PDF preview for Typst-based reports (resumes).
@ -394,8 +401,9 @@ async def export_report(
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Export a report in the requested format.
"""
@ -568,8 +576,9 @@ async def export_report(
async def delete_report(
report_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a report.
"""

View file

@ -10,8 +10,9 @@ from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import NewChatThread, Permission, User, get_async_session
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ async def download_sandbox_file(
thread_id: int,
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Download a file from the Daytona sandbox associated with a chat thread."""
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import (
@ -68,7 +70,7 @@ async def download_sandbox_file(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to access files in this thread",

View file

@ -33,6 +33,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.connectors.github_connector import GitHubConnector
from app.db import (
@ -56,7 +57,7 @@ from app.schemas import (
SearchSourceConnectorUpdate,
)
from app.services.composio_service import ComposioService, get_composio_service
from app.users import current_active_user
from app.users import get_auth_context
# NOTE: connector indexer functions are imported lazily inside each
# ``run_*_indexing`` helper to break a circular import cycle:
@ -143,8 +144,9 @@ class GitHubPATRequest(BaseModel):
@router.post("/github/repositories", response_model=list[dict[str, Any]])
async def list_github_repositories(
pat_request: GitHubPATRequest,
user: User = Depends(current_active_user), # Ensure the user is logged in
auth: AuthContext = Depends(get_auth_context), # Ensure the user is logged in
):
user = auth.user
"""
Fetches a list of repositories accessible by the provided GitHub PAT.
The PAT is used for this request only and is not stored.
@ -173,8 +175,9 @@ async def create_search_source_connector(
..., description="ID of the search space to associate the connector with"
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new search source connector.
Requires CONNECTORS_CREATE permission.
@ -186,7 +189,7 @@ async def create_search_source_connector(
# Check if user has permission to create connectors
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space",
@ -281,8 +284,9 @@ async def read_search_source_connectors(
limit: int = 100,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all search source connectors for a search space.
Requires CONNECTORS_READ permission.
@ -297,7 +301,7 @@ async def read_search_source_connectors(
# Check if user has permission to read connectors
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space",
@ -324,8 +328,9 @@ async def read_search_source_connectors(
async def read_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific search source connector by ID.
Requires CONNECTORS_READ permission.
@ -345,7 +350,7 @@ async def read_search_source_connector(
# Check permission
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector",
@ -367,8 +372,9 @@ async def update_search_source_connector(
connector_id: int,
connector_update: SearchSourceConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a search source connector.
Requires CONNECTORS_UPDATE permission.
@ -386,7 +392,7 @@ async def update_search_source_connector(
# Check permission
await check_permission(
session,
user,
auth,
db_connector.search_space_id,
Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector",
@ -557,8 +563,9 @@ async def update_search_source_connector(
async def delete_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a search source connector and all its associated documents.
@ -588,7 +595,7 @@ async def delete_search_source_connector(
# Check permission
await check_permission(
session,
user,
auth,
db_connector.search_space_id,
Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector",
@ -725,8 +732,9 @@ async def index_connector_content(
description="[Google Drive only] Structured request with folders and files to index",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Index content from a KB connector to a search space.
@ -760,7 +768,7 @@ async def index_connector_content(
# the read/update/delete handlers — not the client-supplied query param.
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_UPDATE.value,
"You don't have permission to index content in this search space",
@ -2645,8 +2653,9 @@ async def create_mcp_connector(
connector_data: MCPConnectorCreate,
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new MCP (Model Context Protocol) connector.
@ -2669,7 +2678,7 @@ async def create_mcp_connector(
# Check user has permission to create connectors
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space",
@ -2724,8 +2733,9 @@ async def create_mcp_connector(
async def list_mcp_connectors(
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all MCP connectors for a search space.
@ -2741,7 +2751,7 @@ async def list_mcp_connectors(
# Check user has permission to read connectors
await check_permission(
session,
user,
auth,
search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space",
@ -2775,8 +2785,9 @@ async def list_mcp_connectors(
async def get_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific MCP connector by ID.
@ -2805,7 +2816,7 @@ async def get_mcp_connector(
# Check user has permission to read connectors
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector",
@ -2828,8 +2839,9 @@ async def update_mcp_connector(
connector_id: int,
connector_update: MCPConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update an MCP connector.
@ -2859,7 +2871,7 @@ async def update_mcp_connector(
# Check user has permission to update connectors
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector",
@ -2904,8 +2916,9 @@ async def update_mcp_connector(
async def delete_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete an MCP connector.
@ -2931,7 +2944,7 @@ async def delete_mcp_connector(
# Check user has permission to delete connectors
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector",
@ -2962,8 +2975,9 @@ async def delete_mcp_connector(
@router.post("/connectors/mcp/test")
async def test_mcp_server_connection(
server_config: dict = Body(...),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Test connection to an MCP server and fetch available tools.
@ -3042,8 +3056,9 @@ DRIVE_CONNECTOR_TYPES = {
async def get_drive_picker_token(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Return an OAuth access token + client ID for the Google Picker API."""
result = await session.execute(
select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)
@ -3054,7 +3069,7 @@ async def get_drive_picker_token(
await check_permission(
session,
user,
auth,
connector.search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to access this connector",
@ -3164,8 +3179,9 @@ async def trust_mcp_tool(
connector_id: int,
body: MCPTrustToolRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Add a tool to the MCP connector's trusted (always-allow) list.
Once trusted, the tool executes without HITL approval on subsequent
@ -3209,8 +3225,9 @@ async def untrust_mcp_tool(
connector_id: int,
body: MCPTrustToolRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Remove a tool from the MCP connector's trusted list.
The tool will require HITL approval again on subsequent calls.

View file

@ -5,22 +5,23 @@ from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
User,
get_async_session,
get_default_roles_config,
)
from app.schemas import (
SearchSpaceApiAccessUpdate,
SearchSpaceCreate,
SearchSpaceRead,
SearchSpaceUpdate,
SearchSpaceWithStats,
)
from app.users import current_active_user
from app.users import allow_any_principal, get_auth_context, require_session_context
from app.utils.rbac import check_permission, check_search_space_access
logger = logging.getLogger(__name__)
@ -74,8 +75,9 @@ async def create_default_roles_and_membership(
async def create_search_space(
search_space: SearchSpaceCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
try:
search_space_data = search_space.model_dump()
@ -108,8 +110,9 @@ async def read_search_spaces(
limit: int = 200,
owned_only: bool = False,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(allow_any_principal),
):
user = auth.user
"""
Get all search spaces the user has access to, with member count and ownership info.
@ -123,11 +126,17 @@ async def read_search_spaces(
# Exclude spaces that are pending background deletion
not_deleting = ~SearchSpace.name.startswith("[DELETING] ")
api_access_filter = (
SearchSpace.api_access_enabled == True # noqa: E712
if auth.is_gated
else True
)
if owned_only:
# Return only search spaces where user is the original creator (user_id)
result = await session.execute(
select(SearchSpace)
.filter(SearchSpace.user_id == user.id, not_deleting)
.filter(SearchSpace.user_id == user.id, not_deleting, api_access_filter)
.order_by(SearchSpace.id.asc())
.offset(skip)
.limit(limit)
@ -137,7 +146,11 @@ async def read_search_spaces(
result = await session.execute(
select(SearchSpace)
.join(SearchSpaceMembership)
.filter(SearchSpaceMembership.user_id == user.id, not_deleting)
.filter(
SearchSpaceMembership.user_id == user.id,
not_deleting,
api_access_filter,
)
.order_by(SearchSpace.id.asc())
.offset(skip)
.limit(limit)
@ -174,6 +187,7 @@ async def read_search_spaces(
created_at=space.created_at,
user_id=space.user_id,
citations_enabled=space.citations_enabled,
api_access_enabled=space.api_access_enabled,
qna_custom_instructions=space.qna_custom_instructions,
ai_file_sort_enabled=space.ai_file_sort_enabled,
member_count=member_count,
@ -192,7 +206,7 @@ async def read_search_spaces(
async def read_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Get a specific search space by ID.
@ -200,7 +214,7 @@ async def read_search_space(
"""
try:
# Check if user has access (is a member)
await check_search_space_access(session, user, search_space_id)
await check_search_space_access(session, auth, search_space_id)
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
@ -225,7 +239,7 @@ async def update_search_space(
search_space_id: int,
search_space_update: SearchSpaceUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Update a search space.
@ -235,7 +249,7 @@ async def update_search_space(
# Check permission
await check_permission(
session,
user,
auth,
search_space_id,
Permission.SETTINGS_UPDATE.value,
"You don't have permission to update this search space",
@ -265,17 +279,65 @@ async def update_search_space(
) from e
@router.put("/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead)
async def update_search_space_api_access(
search_space_id: int,
body: SearchSpaceApiAccessUpdate,
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
"""
Toggle programmatic API/PAT access for a search space.
Requires API_ACCESS_MANAGE permission.
"""
try:
if not auth.is_session:
raise HTTPException(
status_code=403,
detail="This action requires an interactive session",
)
await check_permission(
session,
auth,
search_space_id,
Permission.API_ACCESS_MANAGE.value,
"You don't have permission to manage API access for this search space",
)
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
db_search_space = result.scalars().first()
if not db_search_space:
raise HTTPException(status_code=404, detail="Search space not found")
db_search_space.api_access_enabled = body.api_access_enabled
await session.commit()
await session.refresh(db_search_space)
return db_search_space
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to update API access: {e!s}"
) from e
@router.post("/searchspaces/{search_space_id}/ai-sort")
async def trigger_ai_sort(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Trigger a full AI file sort for all documents in the search space."""
try:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.SETTINGS_UPDATE.value,
"You don't have permission to trigger AI sort on this search space",
@ -305,7 +367,7 @@ async def trigger_ai_sort(
async def delete_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
Delete a search space.
@ -318,7 +380,7 @@ async def delete_search_space(
# Check permission - only those with SETTINGS_DELETE can delete
await check_permission(
session,
user,
auth,
search_space_id,
Permission.SETTINGS_DELETE.value,
"You don't have permission to delete this search space",
@ -374,7 +436,7 @@ async def delete_search_space(
async def list_search_space_snapshots(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
"""
List all public chat snapshots for a search space.
@ -387,6 +449,6 @@ async def list_search_space_snapshots(
snapshots = await list_snapshots_for_search_space(
session=session,
search_space_id=search_space_id,
user=user,
auth=auth,
)
return PublicChatSnapshotsBySpaceResponse(snapshots=snapshots)

View file

@ -17,21 +17,22 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
from app.users import current_active_user
from app.users import get_auth_context, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -78,7 +79,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/slack/connector/add")
async def connect_slack(space_id: int, user: User = Depends(current_active_user)):
async def connect_slack(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Slack OAuth flow.
@ -89,6 +93,7 @@ async def connect_slack(space_id: int, user: User = Depends(current_active_user)
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -525,7 +530,7 @@ async def refresh_slack_token(
async def get_slack_channels(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> list[dict[str, Any]]:
"""
Get list of Slack channels with bot membership status.
@ -541,6 +546,7 @@ async def get_slack_channels(
Returns:
List of channels with id, name, is_private, and is_member fields
"""
user = auth.user
try:
# Get the connector and verify ownership
result = await session.execute(
@ -559,6 +565,8 @@ async def get_slack_channels(
detail="Slack connector not found or access denied",
)
await check_search_space_access(session, auth, connector.search_space_id)
# Get credentials and decrypt bot token
credentials = SlackAuthCredentialsBase.from_dict(connector.config)
token_encryption = get_token_encryption()

View file

@ -18,6 +18,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from stripe import SignatureVerificationError, StripeClient, StripeError
from app.auth.context import AuthContext
from app.config import config
from app.db import (
CreditPurchase,
@ -39,7 +40,7 @@ from app.schemas.stripe import (
StripeWebhookResponse,
UpdateAutoReloadSettingsRequest,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -456,7 +457,7 @@ async def _reconcile_auto_reload_payment_intent(
)
async def create_credit_checkout_session(
body: CreateCreditCheckoutSessionRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> CreateCreditCheckoutSessionResponse:
"""Create a Stripe Checkout Session for buying credit packs.
@ -466,6 +467,7 @@ async def create_credit_checkout_session(
cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page
(ETL), so $1 of credit always buys $1 worth of usage at cost.
"""
user = auth.user
_ensure_credit_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_credit_price_id()
@ -644,7 +646,7 @@ async def stripe_webhook(
@router.get("/finalize-checkout", response_model=FinalizeCheckoutResponse)
async def finalize_checkout(
session_id: str,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> FinalizeCheckoutResponse:
"""Synchronously fulfil a credit checkout session from the success page.
@ -659,6 +661,7 @@ async def finalize_checkout(
Authorization: the session's ``client_reference_id`` must match the
authenticated user's id.
"""
user = auth.user
stripe_client = get_stripe_client()
try:
@ -718,13 +721,14 @@ async def finalize_checkout(
@router.get("/credit-status", response_model=CreditStripeStatusResponse)
async def get_credit_status(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> CreditStripeStatusResponse:
"""Return credit-buying availability and current balance for the frontend.
``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE
divides by 1M when displaying.
"""
user = auth.user
return CreditStripeStatusResponse(
credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED,
credit_micros_balance=user.credit_micros_balance,
@ -733,12 +737,13 @@ async def get_credit_status(
@router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse)
async def get_credit_purchases(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
offset: int = 0,
limit: int = 50,
) -> CreditPurchaseHistoryResponse:
"""Return the authenticated user's credit purchase history."""
user = auth.user
limit = min(limit, 100)
purchases = (
(
@ -759,7 +764,7 @@ async def get_credit_purchases(
@router.get("/purchases", response_model=PagePurchaseHistoryResponse)
async def get_page_purchases(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
offset: int = 0,
limit: int = 50,
@ -768,6 +773,7 @@ async def get_page_purchases(
Page buying is removed; this endpoint stays for historical records.
"""
user = auth.user
limit = min(limit, 100)
purchases = (
(
@ -804,7 +810,7 @@ def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse:
)
async def create_auto_reload_setup_session(
body: CreateAutoReloadSetupSessionRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> CreateAutoReloadSetupSessionResponse:
"""Start a ``mode=setup`` checkout session to save a card for auto-reload.
@ -813,6 +819,7 @@ async def create_auto_reload_setup_session(
Customer so the card can later be charged off-session. On completion the
webhook stores the resulting payment method on the user.
"""
user = auth.user
_ensure_auto_reload_enabled()
_ensure_credit_buying_enabled()
stripe_client = get_stripe_client()
@ -871,16 +878,17 @@ async def create_auto_reload_setup_session(
@router.get("/auto-reload", response_model=AutoReloadSettingsResponse)
async def get_auto_reload_settings(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> AutoReloadSettingsResponse:
"""Return the user's auto-reload configuration and saved-card state."""
user = auth.user
return _auto_reload_settings_response(user)
@router.put("/auto-reload", response_model=AutoReloadSettingsResponse)
async def update_auto_reload_settings(
body: UpdateAutoReloadSettingsRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> AutoReloadSettingsResponse:
"""Update auto-reload preferences.
@ -889,6 +897,7 @@ async def update_auto_reload_settings(
at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and
clears any prior failure flag.
"""
user = auth.user
_ensure_auto_reload_enabled()
locked = (

View file

@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import User, get_async_session
from app.services.memory import (
MemoryRead,
@ -15,7 +16,7 @@ from app.services.memory import (
reset_memory,
save_memory,
)
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
router = APIRouter()
@ -29,9 +30,10 @@ class TeamMemoryUpdate(BaseModel):
async def get_team_memory(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await check_search_space_access(session, user, search_space_id)
user = auth.user
await check_search_space_access(session, auth, search_space_id)
memory_md = await read_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
@ -45,9 +47,10 @@ async def update_team_memory(
search_space_id: int,
body: TeamMemoryUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await check_search_space_access(session, user, search_space_id)
user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await save_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,
@ -63,9 +66,10 @@ async def update_team_memory(
async def reset_team_memory(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
await check_search_space_access(session, user, search_space_id)
user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await reset_memory(
scope=MemoryScope.TEAM,
target_id=search_space_id,

View file

@ -14,15 +14,15 @@ from fastapi.responses import RedirectResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.teams_auth_credentials import TeamsAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -74,7 +74,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/teams/connector/add")
async def connect_teams(space_id: int, user: User = Depends(current_active_user)):
async def connect_teams(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Microsoft Teams OAuth flow.
@ -85,6 +88,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user)
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -16,6 +16,7 @@ from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
Permission,
SearchSpace,
@ -25,7 +26,7 @@ from app.db import (
get_async_session,
)
from app.schemas import VideoPresentationRead
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_permission
router = APIRouter()
@ -37,8 +38,9 @@ async def read_video_presentations(
limit: int = 100,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List video presentations the user has access to.
Requires VIDEO_PRESENTATIONS_READ permission for the search space(s).
@ -49,7 +51,7 @@ async def read_video_presentations(
if search_space_id is not None:
await check_permission(
session,
user,
auth,
search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to read video presentations in this search space",
@ -89,8 +91,9 @@ async def read_video_presentations(
async def read_video_presentation(
video_presentation_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific video presentation by ID.
Requires authentication with VIDEO_PRESENTATIONS_READ permission.
@ -112,7 +115,7 @@ async def read_video_presentation(
await check_permission(
session,
user,
auth,
video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to read video presentations in this search space",
@ -132,8 +135,9 @@ async def read_video_presentation(
async def delete_video_presentation(
video_presentation_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a video presentation.
Requires VIDEO_PRESENTATIONS_DELETE permission for the search space.
@ -151,7 +155,7 @@ async def delete_video_presentation(
await check_permission(
session,
user,
auth,
db_video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_DELETE.value,
"You don't have permission to delete video presentations in this search space",
@ -175,8 +179,9 @@ async def stream_slide_audio(
video_presentation_id: int,
slide_number: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Stream the audio file for a specific slide in a video presentation.
The slide_number is 1-based. Audio path is read from the slides JSONB.
@ -194,7 +199,7 @@ async def stream_slide_audio(
await check_permission(
session,
user,
auth,
video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to access video presentations in this search space",

View file

@ -8,8 +8,8 @@ import time
from fastapi import APIRouter, Depends, HTTPException, Query
from scrapling.fetchers import AsyncFetcher
from app.db import User
from app.users import current_active_user
from app.auth.context import AuthContext
from app.users import require_session_context
from app.utils.proxy import get_proxy_url
router = APIRouter()
@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = {
@router.get("/youtube/playlist-videos")
async def get_playlist_videos(
url: str = Query(..., description="YouTube playlist URL"),
_user: User = Depends(current_active_user),
_auth: AuthContext = Depends(require_session_context),
):
"""Resolve a YouTube playlist URL into individual video URLs."""
match = _PLAYLIST_ID_RE.search(url)

View file

@ -104,6 +104,7 @@ from .search_source_connector import (
SearchSourceConnectorUpdate,
)
from .search_space import (
SearchSpaceApiAccessUpdate,
SearchSpaceBase,
SearchSpaceCreate,
SearchSpaceRead,
@ -243,6 +244,7 @@ __all__ = [
"SearchSourceConnectorUpdate",
# Search space schemas
"SearchSpaceBase",
"SearchSpaceApiAccessUpdate",
"SearchSpaceCreate",
"SearchSpaceRead",
"SearchSpaceUpdate",

View file

@ -0,0 +1,27 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
class PATCreate(BaseModel):
label: str = Field(min_length=1, max_length=120)
expires_in_days: int | None = Field(default=None, gt=0)
class PATCreated(BaseModel):
id: int
label: str
token: str
prefix: str
expires_at: datetime | None = None
class PATRead(BaseModel):
id: int
label: str
prefix: str
expires_at: datetime | None = None
last_used_at: datetime | None = None
created_at: datetime
model_config = ConfigDict(from_attributes=True)

View file

@ -24,11 +24,16 @@ class SearchSpaceUpdate(BaseModel):
ai_file_sort_enabled: bool | None = None
class SearchSpaceApiAccessUpdate(BaseModel):
api_access_enabled: bool
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
id: int
created_at: datetime
user_id: uuid.UUID
citations_enabled: bool
api_access_enabled: bool = False
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
ai_file_sort_enabled: bool = False

View file

@ -9,6 +9,7 @@ from sqlalchemy import delete, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.db import (
ChatComment,
ChatCommentMention,
@ -138,8 +139,9 @@ async def get_comment_thread_participants(
async def get_comments_for_message(
session: AsyncSession,
message_id: int,
user: User,
auth: AuthContext,
) -> CommentListResponse:
user = auth.user
"""
Get all comments for a message with their replies.
@ -169,7 +171,7 @@ async def get_comments_for_message(
# Check permission to read comments
await check_permission(
session,
user,
auth,
search_space_id,
Permission.COMMENTS_READ.value,
"You don't have permission to read comments in this search space",
@ -268,8 +270,9 @@ async def get_comments_for_message(
async def get_comments_for_messages_batch(
session: AsyncSession,
message_ids: list[int],
user: User,
auth: AuthContext,
) -> CommentBatchResponse:
user = auth.user
"""
Batch-fetch comments for multiple messages in a single DB round-trip.
@ -295,7 +298,7 @@ async def get_comments_for_messages_batch(
for ss_id in search_space_ids:
await check_permission(
session,
user,
auth,
ss_id,
Permission.COMMENTS_READ.value,
"You don't have permission to read comments in this search space",
@ -409,8 +412,9 @@ async def create_comment(
session: AsyncSession,
message_id: int,
content: str,
user: User,
auth: AuthContext,
) -> CommentResponse:
user = auth.user
"""
Create a top-level comment on an AI response.
@ -521,8 +525,9 @@ async def create_reply(
session: AsyncSession,
comment_id: int,
content: str,
user: User,
auth: AuthContext,
) -> CommentReplyResponse:
user = auth.user
"""
Create a reply to an existing comment.
@ -657,8 +662,9 @@ async def update_comment(
session: AsyncSession,
comment_id: int,
content: str,
user: User,
auth: AuthContext,
) -> CommentReplyResponse:
user = auth.user
"""
Update a comment's content (author only).
@ -797,8 +803,9 @@ async def update_comment(
async def delete_comment(
session: AsyncSession,
comment_id: int,
user: User,
auth: AuthContext,
) -> dict:
user = auth.user
"""
Delete a comment (author or user with COMMENTS_DELETE permission).
@ -844,9 +851,10 @@ async def delete_comment(
async def get_user_mentions(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int | None = None,
) -> MentionListResponse:
user = auth.user
"""
Get mentions for the current user, optionally filtered by search space.

View file

@ -21,6 +21,7 @@ from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.db import (
ChatVisibility,
NewChatMessage,
@ -163,8 +164,9 @@ def compute_content_hash(messages: list[dict]) -> str:
async def create_snapshot(
session: AsyncSession,
thread_id: int,
user: User,
auth: AuthContext,
) -> dict:
user = auth.user
"""
Create a public snapshot of a chat thread.
@ -186,7 +188,7 @@ async def create_snapshot(
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.PUBLIC_SHARING_CREATE.value,
"You don't have permission to create public share links",
@ -431,8 +433,9 @@ async def get_public_chat(
async def list_snapshots_for_thread(
session: AsyncSession,
thread_id: int,
user: User,
auth: AuthContext,
) -> list[dict]:
user = auth.user
"""List all public snapshots for a thread."""
from app.config import config
@ -447,7 +450,7 @@ async def list_snapshots_for_thread(
# Check permission to view public share links
await check_permission(
session,
user,
auth,
thread.search_space_id,
Permission.PUBLIC_SHARING_VIEW.value,
"You don't have permission to view public share links",
@ -477,14 +480,15 @@ async def list_snapshots_for_thread(
async def list_snapshots_for_search_space(
session: AsyncSession,
search_space_id: int,
user: User,
auth: AuthContext,
) -> list[dict]:
user = auth.user
"""List all public snapshots for a search space."""
from app.config import config
await check_permission(
session,
user,
auth,
search_space_id,
Permission.PUBLIC_SHARING_VIEW.value,
"You don't have permission to view public share links",
@ -534,8 +538,9 @@ async def delete_snapshot(
session: AsyncSession,
thread_id: int,
snapshot_id: int,
user: User,
auth: AuthContext,
) -> bool:
user = auth.user
"""Delete a specific snapshot. Only thread owner can delete."""
# Get snapshot with thread
result = await session.execute(
@ -553,7 +558,7 @@ async def delete_snapshot(
await check_permission(
session,
user,
auth,
snapshot.thread.search_space_id,
Permission.PUBLIC_SHARING_DELETE.value,
"You don't have permission to delete public share links",

View file

@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection,
)
from app.agents.chat.runtime.llm_config import AgentConfig
from app.auth.context import AuthContext
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
@ -33,6 +34,7 @@ async def build_main_agent_for_thread(
filesystem_selection: FilesystemSelection | None,
disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None,
auth_context: AuthContext | None = None,
) -> Any:
return await agent_factory(
llm=llm,
@ -48,4 +50,5 @@ async def build_main_agent_for_thread(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)

View file

@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode,
FilesystemSelection,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService
@ -136,6 +137,7 @@ async def stream_new_chat(
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
user_image_data_urls: list[str] | None = None,
auth_context: AuthContext | None = None,
flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]:
"""Stream a new chat turn using the SurfSense deep agent.
@ -412,6 +414,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
@ -664,6 +667,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "

View file

@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode,
FilesystemSelection,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.chat_session_state_service import set_ai_responding
@ -102,6 +103,7 @@ async def stream_resume_chat(
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
disabled_tools: list[str] | None = None,
auth_context: AuthContext | None = None,
) -> AsyncGenerator[str, None]:
"""Resume a paused HITL turn with the user's decisions.
@ -346,6 +348,7 @@ async def stream_resume_chat(
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
auth_context=auth_context,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -481,6 +484,7 @@ async def stream_resume_chat(
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
auth_context=auth_context,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "

View file

@ -3,7 +3,7 @@ import uuid
from datetime import UTC, datetime
import httpx
from fastapi import Depends, Request, Response
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
from fastapi_users.authentication import (
@ -14,7 +14,9 @@ from fastapi_users.authentication import (
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
Prompt,
@ -23,10 +25,12 @@ from app.db import (
SearchSpaceRole,
User,
async_session_maker,
get_async_session,
get_default_roles_config,
get_user_db,
)
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
@ -298,5 +302,75 @@ auth_backend = AuthenticationBackend(
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)
async def get_auth_context(
request: Request,
session: AsyncSession = Depends(get_async_session),
user_manager: UserManager = Depends(get_user_manager),
) -> AuthContext:
"""Resolve the authenticated principal.
Use this for authorization-sensitive routes where session-vs-PAT matters.
FastAPI-Users still handles JWT mechanics; PATs are resolved here so RBAC
receives the full SurfSense principal instead of a bare User.
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
scheme, _, token = auth_header.partition(" ")
if scheme.lower() != "bearer" or not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
if token.startswith(PAT_PREFIX):
pat = await resolve_pat(session, token)
if pat and pat.user and pat.user.is_active:
maybe_touch_last_used(pat)
return AuthContext.pat_auth(pat.user, pat)
try:
user = await get_jwt_strategy().read_token(token, user_manager)
except Exception:
logger.exception("Failed to read access token")
user = None
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
return AuthContext.session(user)
async def allow_any_principal(
auth: AuthContext = Depends(get_auth_context),
) -> AuthContext:
"""Allow either session or PAT principals for bootstrap probes only.
Routes using this dependency intentionally have no search-space gate.
Adding a new call site is a security decision and must be covered by
the fail-closed PAT allowlist test.
"""
return auth
async def require_session_context(
auth: AuthContext = Depends(get_auth_context),
) -> AuthContext:
"""Require an interactive session and reject PAT-authenticated requests."""
if not auth.is_session:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="This action requires an interactive session",
)
return auth
current_optional_user = fastapi_users.current_user(active=True, optional=True)

View file

@ -0,0 +1,73 @@
from __future__ import annotations
import asyncio
import hashlib
import logging
import secrets
from datetime import UTC, datetime, timedelta
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import PersonalAccessToken, User, async_session_maker
logger = logging.getLogger(__name__)
PAT_PREFIX = "ss_pat_"
PAT_TOKEN_BYTES = 32
LAST_USED_THROTTLE = timedelta(minutes=10)
def generate_pat() -> str:
return f"{PAT_PREFIX}{secrets.token_urlsafe(PAT_TOKEN_BYTES)}"
def hash_pat(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest()
def token_prefix(token: str) -> str:
return token[:16]
async def resolve_pat(
session: AsyncSession,
token: str,
) -> PersonalAccessToken | None:
now = datetime.now(UTC)
result = await session.execute(
select(PersonalAccessToken)
.options(selectinload(PersonalAccessToken.user))
.join(User)
.where(
PersonalAccessToken.token_hash == hash_pat(token),
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now),
User.is_active == True, # noqa: E712
)
)
return result.scalars().first()
async def _touch_last_used(token_id: int) -> None:
try:
async with async_session_maker() as session:
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.id == token_id)
.values(last_used_at=datetime.now(UTC))
)
await session.commit()
except Exception:
logger.exception("Failed to update PAT last_used_at for token %s", token_id)
def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
last_used_at = pat.last_used_at
now = datetime.now(UTC)
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
return
asyncio.create_task(_touch_last_used(pat.id))

View file

@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
User,
has_permission,
)
@ -80,9 +80,33 @@ async def get_user_permissions(
return []
async def _enforce_api_access_gate(
session: AsyncSession,
auth: AuthContext,
search_space_id: int,
search_space: SearchSpace | None = None,
) -> SearchSpace:
if search_space is None:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
if auth.is_gated and not search_space.api_access_enabled:
raise HTTPException(
status_code=403,
detail="API access is not enabled for this search space.",
)
return search_space
async def check_permission(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
required_permission: str,
error_message: str = "You don't have permission to perform this action",
@ -104,7 +128,7 @@ async def check_permission(
Raises:
HTTPException: If user doesn't have access or permission
"""
membership = await get_user_membership(session, user.id, search_space_id)
membership = await get_user_membership(session, auth.user.id, search_space_id)
if not membership:
raise HTTPException(
@ -123,12 +147,14 @@ async def check_permission(
if not has_permission(permissions, required_permission):
raise HTTPException(status_code=403, detail=error_message)
await _enforce_api_access_gate(session, auth, search_space_id)
return membership
async def check_search_space_access(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
) -> SearchSpaceMembership:
"""
@ -146,7 +172,7 @@ async def check_search_space_access(
Raises:
HTTPException: If user doesn't have access
"""
membership = await get_user_membership(session, user.id, search_space_id)
membership = await get_user_membership(session, auth.user.id, search_space_id)
if not membership:
raise HTTPException(
@ -154,6 +180,8 @@ async def check_search_space_access(
detail="You don't have access to this search space",
)
await _enforce_api_access_gate(session, auth, search_space_id)
return membership
@ -179,7 +207,7 @@ async def is_search_space_owner(
async def get_search_space_with_access_check(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
required_permission: str | None = None,
) -> tuple[SearchSpace, SearchSpaceMembership]:
@ -210,10 +238,12 @@ async def get_search_space_with_access_check(
# Check access
if required_permission:
membership = await check_permission(
session, user, search_space_id, required_permission
session, auth, search_space_id, required_permission
)
else:
membership = await check_search_space_access(session, user, search_space_id)
membership = await check_search_space_access(session, auth, search_space_id)
await _enforce_api_access_gate(session, auth, search_space_id, search_space)
return search_space, membership

View file

@ -42,7 +42,7 @@ Maximize logic covered by unit tests; keep integration tests for what genuinely
`conftest.py` is scoped to its directory and below. Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test.
For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
For API integration tests, override `get_async_session` and `get_auth_context` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
## Import mode

View file

@ -40,6 +40,7 @@ import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
ChatVisibility,
NewChatMessage,
@ -395,7 +396,7 @@ class TestAppendMessageRecoveryAfterFinalize:
thread_id=thread_id,
request=request,
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
# Response must echo the SERVER's rich payload, not the FE's
@ -469,7 +470,7 @@ class TestAppendMessageRecoveryAfterFinalize:
thread_id=thread_id,
request=_FakeRequest(fe_request_body),
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
assert fe_response.role == NewChatMessageRole.ASSISTANT
@ -552,7 +553,7 @@ class TestAppendMessageRecoveryAfterFinalize:
}
),
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
assert ok_response.role == NewChatMessageRole.USER
assert ok_response.turn_id is None

View file

@ -16,6 +16,7 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
ChatVisibility,
SearchSpace,
@ -33,6 +34,10 @@ from app.schemas.new_chat import (
pytestmark = pytest.mark.integration
def _auth(user: User) -> AuthContext:
return AuthContext.session(user)
@pytest_asyncio.fixture
async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User:
member = User(
@ -85,7 +90,7 @@ async def _create_thread(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
@ -108,13 +113,13 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert thread.id not in _active_thread_ids(member_threads)
@ -123,7 +128,7 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403
@ -142,24 +147,24 @@ async def test_creator_can_share_thread_and_member_can_list_search_read_it(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
full_thread = await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert updated.visibility == ChatVisibility.SEARCH_SPACE
@ -181,20 +186,20 @@ async def test_rename_and_archive_do_not_reset_shared_visibility(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
renamed = await new_chat_routes.update_thread(
thread_id=thread.id,
thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
archived = await new_chat_routes.update_thread(
thread_id=thread.id,
thread_update=NewChatThreadUpdate(archived=True),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
assert renamed.visibility == ChatVisibility.SEARCH_SPACE
@ -215,7 +220,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
with pytest.raises(HTTPException) as exc_info:
@ -225,7 +230,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403
@ -244,7 +249,7 @@ async def test_creator_can_make_shared_thread_private_again(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
private_again = await new_chat_routes.update_thread_visibility(
@ -253,18 +258,18 @@ async def test_creator_can_make_shared_thread_private_again(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert private_again.visibility == ChatVisibility.PRIVATE
@ -274,6 +279,6 @@ async def test_creator_can_make_shared_thread_private_again(
await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403

View file

@ -15,6 +15,7 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
@ -22,7 +23,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -40,12 +41,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(

View file

@ -1,9 +1,9 @@
"""Notifications integration fixtures.
The app's DB session and current-user dependencies are overridden to ride the
The app's DB session and auth-context dependencies are overridden to ride the
test's transactional `db_session`, so API calls and seeded rows share one
transaction that rolls back per test. Overriding `current_active_user` also
bypasses real JWT auth, so these tests don't depend on AUTH_TYPE.
transaction that rolls back per test. Overriding `get_auth_context` also bypasses
real JWT auth, so these tests don't depend on AUTH_TYPE.
"""
from __future__ import annotations
@ -17,8 +17,9 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.db import User, get_async_session
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -33,12 +34,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(

View file

@ -24,6 +24,7 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.config import config as app_config
from app.db import SearchSpace, User, get_async_session
from app.podcasts.persistence import Podcast, PodcastStatus
@ -39,7 +40,7 @@ from app.podcasts.schemas import (
from app.podcasts.service import PodcastService
from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech
from app.routes.search_spaces_routes import create_default_roles_and_membership
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -54,12 +55,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(
@ -290,7 +291,7 @@ def act_as():
"""
def _act(user: User) -> None:
app.dependency_overrides[current_active_user] = lambda: user
app.dependency_overrides[get_auth_context] = lambda: AuthContext.session(user)
return _act

View file

@ -23,6 +23,7 @@ import pytest
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
@ -109,7 +110,7 @@ class TestConnectorIndexCrossSpaceAuthz:
connector_id=connector_a.id,
search_space_id=space_b.id, # the attacker's own space
session=db_session,
user=attacker,
auth=AuthContext.session(attacker),
)
assert exc_info.value.status_code == 404
@ -140,7 +141,7 @@ class TestConnectorIndexCrossSpaceAuthz:
connector_id=connector.id,
search_space_id=space.id, # the connector's own space
session=db_session,
user=owner,
auth=AuthContext.session(owner),
)
check_permission_mock.assert_awaited_once()

View file

@ -28,6 +28,7 @@ from sqlalchemy import func, select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
@ -42,6 +43,7 @@ from app.routes.obsidian_plugin_routes import (
obsidian_stats,
obsidian_sync,
)
from app.routes.search_spaces_routes import create_default_roles_and_membership
from app.schemas.obsidian_plugin import (
ConnectRequest,
DeleteAck,
@ -65,6 +67,10 @@ pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
def _auth(user: User) -> AuthContext:
return AuthContext.session(user)
def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload:
"""Minimal NotePayload that the schema accepts; the indexer is mocked
out so the values don't have to round-trip through the real pipeline."""
@ -102,6 +108,8 @@ async def race_user_and_space(async_engine):
)
space = SearchSpace(name="Race Space", user_id=user_id)
setup.add_all([user, space])
await setup.flush()
await create_default_roles_and_membership(setup, space.id, user_id)
await setup.commit()
await setup.refresh(space)
space_id = space.id
@ -116,6 +124,14 @@ async def race_user_and_space(async_engine):
text("DELETE FROM search_source_connectors WHERE user_id = :uid"),
{"uid": user_id},
)
await cleanup.execute(
text("DELETE FROM search_space_memberships WHERE search_space_id = :id"),
{"id": space_id},
)
await cleanup.execute(
text("DELETE FROM search_space_roles WHERE search_space_id = :id"),
{"id": space_id},
)
await cleanup.execute(
text("DELETE FROM searchspaces WHERE id = :id"),
{"id": space_id},
@ -154,7 +170,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
)
await obsidian_connect(payload, user=fresh_user, session=s)
await obsidian_connect(payload, auth=_auth(fresh_user), session=s)
results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True)
for r in results:
@ -281,7 +297,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
),
user=fresh_user,
auth=_auth(fresh_user),
session=s,
)
@ -294,7 +310,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
),
user=fresh_user,
auth=_auth(fresh_user),
session=s,
)
@ -337,7 +353,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert connect_resp.connector_id > 0
@ -361,7 +377,7 @@ class TestWireContractSmoke:
_make_note_payload(vault_id, "fail.md", "hash-fail"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -394,7 +410,7 @@ class TestWireContractSmoke:
_make_note_payload(vault_id, "fail.md", "h2"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert sync_resp.indexed == 1
@ -420,7 +436,7 @@ class TestWireContractSmoke:
RenameItem(old_path="missing.md", new_path="x.md"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert isinstance(rename_resp, RenameAck)
@ -441,7 +457,7 @@ class TestWireContractSmoke:
):
delete_resp = await obsidian_delete_notes(
DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert isinstance(delete_resp, DeleteAck)
@ -456,7 +472,7 @@ class TestWireContractSmoke:
# upsert_note was mocked) but the response shape is what we care
# about.
manifest_resp = await obsidian_manifest(
vault_id=vault_id, user=db_user, session=db_session
vault_id=vault_id, auth=_auth(db_user), session=db_session
)
assert isinstance(manifest_resp, ManifestResponse)
assert manifest_resp.vault_id == vault_id
@ -464,7 +480,7 @@ class TestWireContractSmoke:
# 6. /stats — same; row count is 0 because upsert_note was mocked.
stats_resp = await obsidian_stats(
vault_id=vault_id, user=db_user, session=db_session
vault_id=vault_id, auth=_auth(db_user), session=db_session
)
assert isinstance(stats_resp, StatsResponse)
assert stats_resp.vault_id == vault_id
@ -482,7 +498,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -511,7 +527,7 @@ class TestWireContractSmoke:
binary_note,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -533,7 +549,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -562,7 +578,7 @@ class TestWireContractSmoke:
bad_note,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -587,7 +603,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -616,7 +632,7 @@ class TestWireContractSmoke:
mismatched,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)

View file

@ -0,0 +1,71 @@
"""Runtime smoke tests for fail-closed PAT authorization primitives."""
from __future__ import annotations
import pytest
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import PersonalAccessToken, SearchSpace, User
from app.users import allow_any_principal, require_session_context
from app.utils.rbac import check_search_space_access
pytestmark = pytest.mark.integration
def _pat_auth(user: User) -> AuthContext:
pat = PersonalAccessToken(
user_id=user.id,
user=user,
token_hash="0" * 64,
token_prefix="ss_pat_test",
label="Test PAT",
)
return AuthContext.pat_auth(user, pat)
async def test_pat_is_rejected_by_session_only_dependency(db_user: User):
auth = _pat_auth(db_user)
with pytest.raises(HTTPException) as exc_info:
await require_session_context(auth=auth)
assert exc_info.value.status_code == 403
async def test_pat_is_allowed_by_bootstrap_dependency(db_user: User):
auth = _pat_auth(db_user)
assert await allow_any_principal(auth=auth) is auth
async def test_pat_is_rejected_for_api_disabled_space(
db_session: AsyncSession,
db_user: User,
db_search_space: SearchSpace,
):
db_search_space.api_access_enabled = False
await db_session.flush()
auth = _pat_auth(db_user)
with pytest.raises(HTTPException) as exc_info:
await check_search_space_access(db_session, auth, db_search_space.id)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "API access is not enabled for this search space."
async def test_pat_is_allowed_for_api_enabled_space(
db_session: AsyncSession,
db_user: User,
db_search_space: SearchSpace,
):
db_search_space.api_access_enabled = True
await db_session.flush()
auth = _pat_auth(db_user)
membership = await check_search_space_access(db_session, auth, db_search_space.id)
assert membership.user_id == db_user.id
assert membership.search_space_id == db_search_space.id

View file

@ -15,6 +15,7 @@ import pytest
from fastapi import HTTPException
import app.automations.services.automation as automation_mod
from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
from app.automations.schemas.definition.envelope import (
AutomationDefinition,
@ -45,7 +46,8 @@ class _FakeSession:
def _service(search_space: Any) -> AutomationService:
return AutomationService(
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
session=_FakeSession(search_space),
auth=AuthContext.session(SimpleNamespace(id="u-1")),
)

View file

@ -38,7 +38,9 @@ async def cleanup_supervisors():
@pytest.mark.asyncio
async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook")
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
await byo_long_poll.start_byo_long_poll_supervisors()
@ -47,9 +49,11 @@ async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
@pytest.mark.asyncio
async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([])
monkeypatch.setattr(
@ -67,9 +71,11 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc
async def test_start_byo_long_poll_spawns_one_supervisor_per_account(
mocker, monkeypatch
):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult(accounts)
@ -115,9 +121,11 @@ async def test_supervisor_retries_after_run_returns(mocker, monkeypatch):
@pytest.mark.asyncio
async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
monkeypatch.setattr(

View file

@ -27,6 +27,7 @@ async def test_inbox_worker_claims_and_processes_in_fastapi_process(
async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch):
started = asyncio.Event()
stopped = asyncio.Event()
monkeypatch.setattr(inbox_worker.config, "GATEWAY_ENABLED", True)
async def run_forever():
started.set()

View file

@ -9,6 +9,7 @@ from types import SimpleNamespace
import pytest
from app.auth.context import AuthContext
from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform
from app.routes import gateway_webhook_routes as routes
@ -333,7 +334,9 @@ async def test_discord_gateway_install_returns_oauth_url(monkeypatch, mocker):
response = await routes.install_discord_gateway(
search_space_id=123,
user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"),
auth=AuthContext.session(
SimpleNamespace(id="00000000-0000-0000-0000-000000000001")
),
session=mocker.AsyncMock(),
)

View file

@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.auth.context import AuthContext
from app.routes import agent_revert_route
from app.services.revert_service import RevertOutcome
@ -147,7 +148,7 @@ class TestFlagGuard:
thread_id=1,
chat_turn_id="42:1700000000000",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert getattr(exc.value, "status_code", None) == 503
@ -167,7 +168,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-empty",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "ok"
assert response.total == 0
@ -209,7 +210,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-3",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "ok"
@ -248,7 +249,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-i",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "ok"
assert response.already_reverted == 1
@ -275,7 +276,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-rev",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "ok"
assert response.results[0].status == "skipped"
@ -315,7 +316,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-mix",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "partial"
assert response.reverted == 1
@ -354,7 +355,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-fail",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.status == "partial"
assert response.failed == 1
@ -386,7 +387,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-perm",
session=session,
user=_FakeUser(id="not-owner"),
auth=AuthContext.session(_FakeUser(id="not-owner")),
)
assert response.status == "partial"
assert response.results[0].status == "permission_denied"
@ -449,7 +450,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-mixed-all",
session=session,
user=_FakeUser(), # only id=7 has a different user_id
auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id
)
assert response.total == len(rows) == 6
@ -518,7 +519,7 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-race",
session=session,
user=_FakeUser(),
auth=AuthContext.session(_FakeUser()),
)
assert response.failed == 0, (

View file

@ -0,0 +1,101 @@
"""Static guards for the fail-closed PAT authorization model."""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
BACKEND_ROOT = Path(__file__).resolve().parents[2]
APP_ROOT = BACKEND_ROOT / "app"
ALLOW_ANY_EXPECTED = {
"app.py": "auth: AuthContext = Depends(allow_any_principal)",
"routes/obsidian_plugin_routes.py": (
"_auth: AuthContext = Depends(allow_any_principal)"
),
"routes/search_spaces_routes.py": (
"auth: AuthContext = Depends(allow_any_principal)"
),
}
CONNECTOR_LISTERS = [
"routes/slack_add_connector_route.py",
"routes/composio_routes.py",
"routes/google_drive_add_connector_route.py",
"routes/discord_add_connector_route.py",
"routes/dropbox_add_connector_route.py",
"routes/onedrive_add_connector_route.py",
]
def _python_files() -> list[Path]:
return [
path
for path in APP_ROOT.rglob("*.py")
if "__pycache__" not in path.parts
]
def test_current_active_user_is_removed_from_app_tree() -> None:
offenders = [
str(path.relative_to(BACKEND_ROOT))
for path in _python_files()
if "current_active_user" in path.read_text()
]
assert offenders == []
def test_allow_any_principal_is_only_used_by_bootstrap_allowlist() -> None:
actual: dict[str, int] = {}
for path in _python_files():
text = path.read_text()
count = text.count("Depends(allow_any_principal)")
if count:
actual[str(path.relative_to(APP_ROOT))] = count
assert actual == dict.fromkeys(ALLOW_ANY_EXPECTED, 1)
for rel_path, expected_snippet in ALLOW_ANY_EXPECTED.items():
text = (APP_ROOT / rel_path).read_text()
assert expected_snippet in text
def test_connector_listers_route_pat_through_search_space_gate() -> None:
for rel_path in CONNECTOR_LISTERS:
text = (APP_ROOT / rel_path).read_text()
assert "auth: AuthContext = Depends(get_auth_context)" in text, rel_path
assert (
"await check_search_space_access(session, auth, connector.search_space_id)"
in text
), rel_path
def test_identity_routes_are_session_only() -> None:
session_only_files = [
"routes/prompts_routes.py",
"routes/memory_routes.py",
"routes/model_list_routes.py",
"routes/agent_flags_route.py",
"routes/youtube_routes.py",
"routes/incentive_tasks_routes.py",
"notifications/api/api.py",
"routes/chat_comments_routes.py",
"routes/public_chat_routes.py",
]
for rel_path in session_only_files:
text = (APP_ROOT / rel_path).read_text()
assert "require_session_context" in text, rel_path
assert "Depends(get_auth_context)" not in text, rel_path
def test_model_connection_personal_writes_default_to_session_required() -> None:
text = (APP_ROOT / "routes/model_connections_routes.py").read_text()
assert "allow_spaceless_pat: bool = False" in text
assert "auth.is_gated and not allow_spaceless_pat" in text
assert "Managing personal model connections requires an interactive session" in text

View file

@ -16,7 +16,7 @@ const ApiKeyForm = () => {
const validateForm = () => {
if (!apiKey) {
setError("API key is required");
setError("Personal access token is required");
return false;
}
setError("");
@ -39,11 +39,11 @@ const ApiKeyForm = () => {
setLoading(false);
if (response.ok) {
// Store the API key as the token
// Store the PAT as the bearer token for existing background handlers.
await storage.set("token", apiKey);
navigation("/");
} else {
setError("Invalid API key. Please check and try again.");
setError("Invalid personal access token. Please check and try again.");
}
} catch (error) {
setLoading(false);
@ -67,15 +67,15 @@ const ApiKeyForm = () => {
<div className="bg-gray-800/70 backdrop-blur-sm rounded-xl shadow-xl border border-gray-700 p-6">
<div className="space-y-6">
<h2 className="text-xl font-medium text-white">Enter your API Key</h2>
<h2 className="text-xl font-medium text-white">Enter your personal access token</h2>
<p className="text-gray-400 text-sm">
Your API key connects this extension to the SurfSense.
Your personal access token connects this extension to SurfSense.
</p>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="space-y-2">
<label htmlFor="apiKey" className="text-sm font-medium text-gray-300">
API Key
Personal access token
</label>
<input
type="text"
@ -83,7 +83,7 @@ const ApiKeyForm = () => {
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
className="w-full px-3 py-2 bg-gray-900/50 border border-gray-700 rounded-md focus:outline-none focus:ring-2 focus:ring-teal-500 text-white placeholder:text-gray-500"
placeholder="Enter your API key"
placeholder="Enter your personal access token"
/>
{error && <p className="text-red-400 text-sm mt-1">{error}</p>}
</div>
@ -106,7 +106,7 @@ const ApiKeyForm = () => {
<div className="text-center mt-4">
<p className="text-sm text-gray-400">
Need an API key?{" "}
Need a personal access token?{" "}
<a
href="https://www.surfsense.com"
target="_blank"

View file

@ -51,7 +51,7 @@ Open **Settings → SurfSense** in Obsidian and fill in:
| Setting | Value |
| --- | --- |
| Server URL | `https://surfsense.com` for SurfSense Cloud, or your self-hosted URL |
| API token | Copy from the *Connectors → Obsidian* dialog in the SurfSense web app |
| API token | Create a personal access token from the *Connectors → Obsidian* dialog or *User settings → API access* in the SurfSense web app |
| Search space | Pick the search space this vault should sync into |
| Vault name | Defaults to your Obsidian vault name; rename if you have multiple vaults |
| Sync mode | *Auto* (recommended) or *Manual* |
@ -62,11 +62,6 @@ The connector row appears automatically inside SurfSense the first time the
plugin successfully calls `/obsidian/connect`. You can manage or delete it
from *Connectors → Obsidian* in the web app.
> **Token lifetime.** The web app currently issues 24-hour JWTs. If you see
> *"token expired"* in the plugin status bar, paste a fresh token from the
> SurfSense web app. Long-lived personal access tokens are coming in a future
> release.
## Mobile
The plugin works on Obsidian for iOS and Android. Sync runs whenever the

View file

@ -22,11 +22,11 @@ import type {
*
* Auth + wire contract:
* - Every request carries `Authorization: Bearer <token>` only. No
* custom headers the backend identifies the caller from the JWT
* custom headers the backend identifies the caller from the PAT
* and feature-detects the API via the `capabilities` array on
* `/health` and `/connect`.
* - 401 surfaces as `AuthError` so the orchestrator can show the
* "token expired, paste a fresh one" UX.
* "token invalid or expired" UX.
* - HealthResponse / ConnectResponse use index signatures so any
* additive backend field (e.g. new capabilities) parses without
* breaking the decoder. This mirrors `ConfigDict(extra='ignore')`

Some files were not shown because too many files have changed in this diff Show more