mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
Merge pull request #1524 from AnishSarkar22/feat/api-key
feat(auth): replace JWT-as-API-key with hashed PATs + per-space API access gate
This commit is contained in:
commit
a0f44b283c
118 changed files with 2273 additions and 859 deletions
|
|
@ -84,6 +84,9 @@ SECRET_KEY=SECRET
|
|||
# JWT Token Lifetimes (optional, defaults shown)
|
||||
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
|
||||
# Personal Access Tokens (PATs). Empty/unset = no maximum; users may create
|
||||
# never-expiring PATs. When set, PAT creation requires an expiry <= this many days.
|
||||
# PAT_MAX_EXPIRY_DAYS=
|
||||
|
||||
NEXT_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,83 @@
|
|||
"""Add personal access tokens and search-space API access gate.
|
||||
|
||||
Revision ID: 166
|
||||
Revises: 165
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "166"
|
||||
down_revision: str | None = "165"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS personal_access_tokens (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token_hash VARCHAR(64) NOT NULL,
|
||||
token_prefix VARCHAR(16) NOT NULL,
|
||||
label VARCHAR NOT NULL,
|
||||
expires_at TIMESTAMP WITH TIME ZONE,
|
||||
last_used_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS ix_personal_access_tokens_token_hash "
|
||||
"ON personal_access_tokens (token_hash)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_user_id "
|
||||
"ON personal_access_tokens (user_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_id "
|
||||
"ON personal_access_tokens (id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_created_at "
|
||||
"ON personal_access_tokens (created_at)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_personal_access_tokens_expires_at "
|
||||
"ON personal_access_tokens (expires_at)"
|
||||
)
|
||||
|
||||
bind = op.get_bind()
|
||||
api_access_column_exists = bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.columns
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'searchspaces'
|
||||
AND column_name = 'api_access_enabled'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).scalar()
|
||||
|
||||
op.execute(
|
||||
"ALTER TABLE searchspaces ADD COLUMN IF NOT EXISTS "
|
||||
"api_access_enabled BOOLEAN NOT NULL DEFAULT false"
|
||||
)
|
||||
|
||||
if not api_access_column_exists:
|
||||
op.execute("UPDATE searchspaces SET api_access_enabled = true")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled"
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS personal_access_tokens")
|
||||
|
|
@ -34,6 +34,7 @@ from app.agents.chat.runtime.llm_config import AgentConfig
|
|||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.user_tool_allowlist import (
|
||||
|
|
@ -73,6 +74,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
auth_context: AuthContext | None = None,
|
||||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
|
||||
|
||||
|
|
@ -139,6 +141,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
"connector_service": connector_service,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"user_id": user_id,
|
||||
"auth_context": auth_context,
|
||||
"thread_id": thread_id,
|
||||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
|
|
|
|||
|
|
@ -30,9 +30,10 @@ from pydantic import ValidationError
|
|||
from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.schemas.api import AutomationCreate
|
||||
from app.automations.services.automation import AutomationService
|
||||
from app.db import User, async_session_maker
|
||||
from app.db import async_session_maker
|
||||
from app.utils.content_utils import extract_text_content
|
||||
|
||||
from .prompt import build_draft_prompt
|
||||
|
|
@ -47,6 +48,7 @@ def create_create_automation_tool(
|
|||
search_space_id: int,
|
||||
user_id: str | UUID,
|
||||
llm: Any,
|
||||
auth_context: AuthContext | None = None,
|
||||
):
|
||||
"""Factory for the ``create_automation`` tool.
|
||||
|
||||
|
|
@ -56,8 +58,6 @@ def create_create_automation_tool(
|
|||
``AsyncSession`` is opened per call to avoid stale sessions on
|
||||
compiled-agent cache hits (same pattern as the Notion / memory tools).
|
||||
"""
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""Draft + save an automation from a natural-language intent.
|
||||
|
|
@ -165,14 +165,17 @@ def create_create_automation_tool(
|
|||
"issues": _format_validation_issues(exc),
|
||||
}
|
||||
|
||||
if auth_context is None:
|
||||
logger.error(
|
||||
"create_automation called without AuthContext; refusing to persist"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "authorization context missing for automation creation",
|
||||
}
|
||||
|
||||
async with async_session_maker() as session:
|
||||
user = await session.get(User, uid)
|
||||
if user is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "user not found in this session",
|
||||
}
|
||||
service = AutomationService(session=session, user=user)
|
||||
service = AutomationService(session=session, auth=auth_context)
|
||||
created = await service.create(final_validated)
|
||||
return {
|
||||
"status": "saved",
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
|
|||
return create_create_automation_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
auth_context=deps.get("auth_context"),
|
||||
llm=deps["llm"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from app.agents.chat.runtime.checkpointer import (
|
|||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import (
|
||||
config,
|
||||
initialize_image_gen_router,
|
||||
|
|
@ -34,7 +35,7 @@ from app.config import (
|
|||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.db import create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
from app.gateway.byo_long_poll import (
|
||||
start_byo_long_poll_supervisors,
|
||||
|
|
@ -55,7 +56,7 @@ from app.routes import router as crud_router
|
|||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.session_events import register_session_hooks
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users
|
||||
from app.utils.perf import log_system_snapshot
|
||||
|
||||
_error_logger = logging.getLogger("surfsense.errors")
|
||||
|
|
@ -1032,7 +1033,7 @@ async def readiness_check():
|
|||
|
||||
@app.get("/verify-token")
|
||||
async def authenticated_route(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(allow_any_principal),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
return {"message": "Token is valid"}
|
||||
return {"message": "Token is valid", "method": auth.method}
|
||||
|
|
|
|||
1
surfsense_backend/app/auth/__init__.py
Normal file
1
surfsense_backend/app/auth/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Authentication principals and helpers."""
|
||||
38
surfsense_backend/app/auth/context.py
Normal file
38
surfsense_backend/app/auth/context.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from app.db import PersonalAccessToken, User
|
||||
|
||||
AuthMethod = Literal["session", "pat", "system"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
"""Typed principal for authorization decisions."""
|
||||
|
||||
user: User
|
||||
method: AuthMethod
|
||||
pat: PersonalAccessToken | None = None
|
||||
source: str | None = None
|
||||
|
||||
@classmethod
|
||||
def session(cls, user: User) -> AuthContext:
|
||||
return cls(user=user, method="session")
|
||||
|
||||
@classmethod
|
||||
def pat_auth(cls, user: User, pat: PersonalAccessToken) -> AuthContext:
|
||||
return cls(user=user, method="pat", pat=pat)
|
||||
|
||||
@classmethod
|
||||
def system(cls, user: User, source: str) -> AuthContext:
|
||||
return cls(user=user, method="system", source=source)
|
||||
|
||||
@property
|
||||
def is_gated(self) -> bool:
|
||||
return self.method == "pat"
|
||||
|
||||
@property
|
||||
def is_session(self) -> bool:
|
||||
return self.method == "session"
|
||||
|
|
@ -16,7 +16,8 @@ from app.agents.chat.runtime.mention_resolver import (
|
|||
substitute_in_text,
|
||||
)
|
||||
from app.agents.chat.shared.context import SurfSenseContextSchema
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility, User, async_session_maker
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
from ...types import ActionContext
|
||||
|
|
@ -147,6 +148,12 @@ async def run_agent_task(
|
|||
decision = "approve" if auto_approve_all else "reject"
|
||||
|
||||
async with async_session_maker() as agent_session:
|
||||
auth_context = None
|
||||
if ctx.creator_user_id:
|
||||
user = await agent_session.get(User, ctx.creator_user_id)
|
||||
if user is not None:
|
||||
auth_context = AuthContext.system(user, source="automation")
|
||||
|
||||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
|
|
@ -168,6 +175,7 @@ async def run_agent_task(
|
|||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
|
||||
agent_query, runtime_context = await _resolve_mention_context(
|
||||
|
|
|
|||
|
|
@ -27,17 +27,19 @@ from app.automations.services.model_policy import (
|
|||
)
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, SearchSpace, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, SearchSpace, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class AutomationService:
|
||||
"""Lifecycle of the ``Automation`` resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
self.user = auth.user
|
||||
|
||||
async def create(self, payload: AutomationCreate) -> Automation:
|
||||
"""Create an automation and its initial triggers in one transaction."""
|
||||
|
|
@ -235,7 +237,7 @@ class AutomationService:
|
|||
async def _authorize(self, search_space_id: int, permission: str) -> None:
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -274,6 +276,6 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
|
|||
|
||||
def get_automation_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AutomationService:
|
||||
return AutomationService(session=session, user=user)
|
||||
return AutomationService(session=session, auth=auth)
|
||||
|
|
|
|||
|
|
@ -8,17 +8,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class RunService:
|
||||
"""Read-only access to ``AutomationRun`` history."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
|
||||
async def list(
|
||||
self,
|
||||
|
|
@ -63,7 +64,7 @@ class RunService:
|
|||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -73,6 +74,6 @@ class RunService:
|
|||
|
||||
def get_run_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> RunService:
|
||||
return RunService(session=session, user=user)
|
||||
return RunService(session=session, auth=auth)
|
||||
|
|
|
|||
|
|
@ -14,17 +14,18 @@ from app.automations.persistence.models.trigger import AutomationTrigger
|
|||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class TriggerService:
|
||||
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
|
||||
async def add(
|
||||
self, *, automation_id: int, payload: TriggerCreate
|
||||
|
|
@ -101,7 +102,7 @@ class TriggerService:
|
|||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -144,6 +145,6 @@ def _initial_next_fire(
|
|||
|
||||
def get_trigger_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> TriggerService:
|
||||
return TriggerService(session=session, user=user)
|
||||
return TriggerService(session=session, auth=auth)
|
||||
|
|
|
|||
|
|
@ -919,6 +919,10 @@ class Config:
|
|||
REFRESH_TOKEN_LIFETIME_SECONDS = int(
|
||||
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
|
||||
)
|
||||
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
|
||||
PAT_MAX_EXPIRY_DAYS = (
|
||||
int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
|
||||
)
|
||||
|
||||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
|
|
|||
|
|
@ -368,6 +368,9 @@ class Permission(StrEnum):
|
|||
SETTINGS_UPDATE = "settings:update"
|
||||
SETTINGS_DELETE = "settings:delete" # Delete the entire search space
|
||||
|
||||
# API Access
|
||||
API_ACCESS_MANAGE = "api_access:manage"
|
||||
|
||||
# Public Sharing
|
||||
PUBLIC_SHARING_VIEW = "public_sharing:view"
|
||||
PUBLIC_SHARING_CREATE = "public_sharing:create"
|
||||
|
|
@ -1693,6 +1696,9 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
citations_enabled = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
) # Enable/disable citations
|
||||
api_access_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
qna_custom_instructions = Column(
|
||||
Text, nullable=True, default=""
|
||||
) # User's custom instructions
|
||||
|
|
@ -2330,6 +2336,11 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
personal_access_tokens = relationship(
|
||||
"PersonalAccessToken",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -2462,6 +2473,11 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
personal_access_tokens = relationship(
|
||||
"PersonalAccessToken",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class AgentActionLog(BaseModel):
|
||||
|
|
@ -2712,6 +2728,36 @@ class RefreshToken(Base, TimestampMixin):
|
|||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
class PersonalAccessToken(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores hashed Personal Access Tokens for programmatic API access.
|
||||
Plaintext tokens are shown once on creation and are never persisted.
|
||||
"""
|
||||
|
||||
__tablename__ = "personal_access_tokens"
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user = relationship("User", back_populates="personal_access_tokens")
|
||||
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
token_prefix = Column(String(16), nullable=False)
|
||||
label = Column(String, nullable=False)
|
||||
expires_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True)
|
||||
last_used_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at is not None and datetime.now(UTC) >= self.expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return not self.is_expired
|
||||
|
||||
|
||||
# Register model packages that live outside this file so their classes
|
||||
# are present in Base.metadata before configure_mappers() resolves any
|
||||
# string-based relationship() references.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from fastapi.responses import StreamingResponse
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, Permission, User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Document, Permission, get_async_session
|
||||
from app.file_storage.persistence.enums import DocumentFileKind
|
||||
from app.file_storage.schemas import DocumentFileRead
|
||||
from app.file_storage.service import (
|
||||
|
|
@ -17,14 +18,14 @@ from app.file_storage.service import (
|
|||
list_document_files,
|
||||
open_document_file_stream,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _load_readable_document(
|
||||
*, document_id: int, session: AsyncSession, user: User
|
||||
*, document_id: int, session: AsyncSession, auth: AuthContext
|
||||
) -> Document:
|
||||
"""Load a document the user may read, or raise 404/403."""
|
||||
document = (
|
||||
|
|
@ -35,7 +36,7 @@ async def _load_readable_document(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -57,10 +58,10 @@ def _content_disposition(filename: str) -> str:
|
|||
async def read_document_files(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> list[DocumentFileRead]:
|
||||
"""Return metadata for every stored file of a document (gates the UI)."""
|
||||
await _load_readable_document(document_id=document_id, session=session, user=user)
|
||||
await _load_readable_document(document_id=document_id, session=session, auth=auth)
|
||||
records = await list_document_files(session, document_id=document_id)
|
||||
return [DocumentFileRead.model_validate(r) for r in records]
|
||||
|
||||
|
|
@ -69,10 +70,10 @@ async def read_document_files(
|
|||
async def download_original_document_file(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> StreamingResponse:
|
||||
"""Stream the document's original uploaded file."""
|
||||
await _load_readable_document(document_id=document_id, session=session, user=user)
|
||||
await _load_readable_document(document_id=document_id, session=session, auth=auth)
|
||||
|
||||
record = await get_document_file(
|
||||
session, document_id=document_id, kind=DocumentFileKind.ORIGINAL
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
|
|||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ExternalChatBinding, NewChatMessage
|
||||
from app.gateway.auth_invariant import assert_authorization_invariant
|
||||
from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent
|
||||
|
|
@ -64,6 +65,7 @@ async def call_agent_for_gateway(
|
|||
request_id: str | None = None,
|
||||
) -> None:
|
||||
user = await assert_authorization_invariant(session, binding)
|
||||
auth_context = AuthContext.system(user, source="gateway")
|
||||
thread = await get_or_create_thread_for_binding(session, binding)
|
||||
await session.commit()
|
||||
|
||||
|
|
@ -81,6 +83,7 @@ async def call_agent_for_gateway(
|
|||
current_user_display_name=user.display_name or "A team member",
|
||||
disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
|
||||
request_id=request_id or "gateway",
|
||||
auth_context=auth_context,
|
||||
)
|
||||
events = _events_from_sse(stream)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ExternalChatBinding, Permission, User
|
||||
from app.gateway.bindings import suspend_binding
|
||||
from app.observability.metrics import record_gateway_auth_invariant_failure
|
||||
|
|
@ -39,11 +40,13 @@ async def assert_authorization_invariant(
|
|||
if user is None:
|
||||
await _fail(session, binding, "owner_missing")
|
||||
|
||||
auth = AuthContext.system(user, source="gateway")
|
||||
|
||||
try:
|
||||
await check_search_space_access(session, user, binding.search_space_id)
|
||||
await check_search_space_access(session, auth, binding.search_space_id)
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
binding.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"External chat owner no longer has permission to chat in this search space",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|||
from sqlalchemy import case, desc, func, literal, literal_column, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.notifications.api.schemas import (
|
||||
BatchUnreadCountResponse,
|
||||
CategoryUnreadCount,
|
||||
|
|
@ -27,7 +28,7 @@ from app.notifications.api.transform import (
|
|||
from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.types import NotificationCategory, NotificationType
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
|
@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||
@router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse)
|
||||
async def get_unread_counts_batch(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> BatchUnreadCountResponse:
|
||||
"""Unread counts for every category in a single query."""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -86,10 +88,11 @@ async def get_unread_counts_batch(
|
|||
@router.get("/source-types", response_model=SourceTypesResponse)
|
||||
async def get_notification_source_types(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> SourceTypesResponse:
|
||||
"""Distinct connector/document source types for the Status tab filter."""
|
||||
user = auth.user
|
||||
base_filter = [Notification.user_id == user.id]
|
||||
|
||||
if search_space_id is not None:
|
||||
|
|
@ -160,7 +163,7 @@ async def get_unread_count(
|
|||
category: NotificationCategory | None = Query(
|
||||
None, description="Filter by category: 'comments' or 'status'"
|
||||
),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> UnreadCountResponse:
|
||||
"""Total and recent (within sync window) unread counts for the user.
|
||||
|
|
@ -168,6 +171,7 @@ async def get_unread_count(
|
|||
Returning both lets a client hold the older count static while
|
||||
live-syncing the recent ones.
|
||||
"""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -230,10 +234,11 @@ async def list_notifications(
|
|||
),
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of items to skip"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> NotificationListResponse:
|
||||
"""Paginated inbox fallback for items outside the Zero sync window."""
|
||||
user = auth.user
|
||||
query = select(Notification).where(Notification.user_id == user.id)
|
||||
count_query = select(func.count(Notification.id)).where(
|
||||
Notification.user_id == user.id
|
||||
|
|
@ -328,10 +333,11 @@ async def list_notifications(
|
|||
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
|
||||
async def mark_notification_as_read(
|
||||
notification_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkReadResponse:
|
||||
"""Mark one of the user's notifications read; Zero syncs the change."""
|
||||
user = auth.user
|
||||
# Scope to the caller's own notifications.
|
||||
result = await session.execute(
|
||||
select(Notification).where(
|
||||
|
|
@ -364,10 +370,11 @@ async def mark_notification_as_read(
|
|||
|
||||
@router.patch("/read-all", response_model=MarkAllReadResponse)
|
||||
async def mark_all_notifications_as_read(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkAllReadResponse:
|
||||
"""Mark all of the user's notifications read; Zero syncs the changes."""
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
update(Notification)
|
||||
.where(
|
||||
|
|
|
|||
|
|
@ -18,12 +18,12 @@ from fastapi.responses import StreamingResponse
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config as app_config
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.podcasts.generation.brief import propose_brief
|
||||
|
|
@ -42,7 +42,7 @@ from app.podcasts.voices import (
|
|||
provider_from_service,
|
||||
render_voice_preview,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
from .schemas import (
|
||||
|
|
@ -63,13 +63,14 @@ async def list_podcasts(
|
|||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
|
||||
if search_space_id is not None:
|
||||
await _require(session, user, search_space_id, Permission.PODCASTS_READ)
|
||||
await _require(session, auth, search_space_id, Permission.PODCASTS_READ)
|
||||
query = (
|
||||
select(Podcast)
|
||||
.where(Podcast.search_space_id == search_space_id)
|
||||
|
|
@ -132,7 +133,7 @@ async def list_languages():
|
|||
@router.get("/podcasts/voices/{voice_id}/preview")
|
||||
async def preview_voice(
|
||||
voice_id: str,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""A short audio sample of a voice, so users pick by sound."""
|
||||
if not app_config.TTS_SERVICE:
|
||||
|
|
@ -156,9 +157,9 @@ async def preview_voice(
|
|||
async def create_podcast(
|
||||
body: CreatePodcastRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE)
|
||||
await _require(session, auth, body.search_space_id, Permission.PODCASTS_CREATE)
|
||||
|
||||
service = PodcastService(session)
|
||||
podcast = await service.create(
|
||||
|
|
@ -185,9 +186,9 @@ async def create_podcast(
|
|||
async def get_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
|
|
@ -196,9 +197,9 @@ async def update_spec(
|
|||
podcast_id: int,
|
||||
body: UpdateSpecRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).update_spec(
|
||||
podcast, body.spec, body.expected_version
|
||||
|
|
@ -211,10 +212,10 @@ async def update_spec(
|
|||
async def approve_brief(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""Approve the brief and start drafting the transcript."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).begin_drafting(podcast)
|
||||
await session.commit()
|
||||
|
|
@ -228,10 +229,10 @@ async def approve_brief(
|
|||
async def regenerate_transcript(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""Reopen the brief gate for a fresh take; drafting waits for re-approval."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).regenerate(podcast)
|
||||
await session.commit()
|
||||
|
|
@ -242,10 +243,10 @@ async def regenerate_transcript(
|
|||
async def revert_regeneration(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""Back out of a regeneration and return to the finished episode."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).revert_regeneration(podcast)
|
||||
await session.commit()
|
||||
|
|
@ -256,9 +257,9 @@ async def revert_regeneration(
|
|||
async def cancel_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).cancel(podcast)
|
||||
await session.commit()
|
||||
|
|
@ -269,9 +270,9 @@ async def cancel_podcast(
|
|||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_DELETE)
|
||||
await purge_audio(podcast)
|
||||
await session.delete(podcast)
|
||||
await session.commit()
|
||||
|
|
@ -282,9 +283,9 @@ async def delete_podcast(
|
|||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
|
||||
podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
|
||||
|
||||
if podcast.storage_key:
|
||||
# Verify first so a missing object is a 404, not a mid-stream crash.
|
||||
|
|
@ -323,13 +324,13 @@ async def stream_podcast(
|
|||
|
||||
async def _require(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
permission: Permission,
|
||||
) -> None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
permission.value,
|
||||
"You don't have permission for podcasts in this search space",
|
||||
|
|
@ -338,14 +339,14 @@ async def _require(
|
|||
|
||||
async def _load(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
podcast_id: int,
|
||||
permission: Permission,
|
||||
) -> Podcast:
|
||||
podcast = await PodcastRepository(session).get(podcast_id)
|
||||
if podcast is None:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
await _require(session, user, podcast.search_space_id, permission)
|
||||
await _require(session, auth, podcast.search_space_id, permission)
|
||||
return podcast
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from .notes_routes import router as notes_router
|
|||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
from .onedrive_add_connector_route import router as onedrive_add_connector_router
|
||||
from .personal_access_tokens_routes import router as personal_access_tokens_router
|
||||
from .prompts_routes import router as prompts_router
|
||||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
|
|
@ -113,6 +114,7 @@ router.include_router(slack_add_connector_router)
|
|||
router.include_router(teams_add_connector_router)
|
||||
router.include_router(onedrive_add_connector_router)
|
||||
router.include_router(obsidian_plugin_router) # Obsidian plugin push API
|
||||
router.include_router(personal_access_tokens_router) # Personal access token manager
|
||||
router.include_router(discord_add_connector_router)
|
||||
router.include_router(jira_add_connector_router)
|
||||
router.include_router(confluence_add_connector_router)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
|
|
@ -36,7 +37,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -111,8 +112,9 @@ async def list_thread_actions(
|
|||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AgentActionListResponse:
|
||||
user = auth.user
|
||||
"""List agent actions for a thread, newest first.
|
||||
|
||||
Authorization:
|
||||
|
|
@ -132,7 +134,7 @@ async def list_thread_actions(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view this thread's action log.",
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import (
|
|||
AgentFeatureFlags,
|
||||
get_flags,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel):
|
|||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
async def get_agent_flags(
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
) -> AgentFeatureFlagsRead:
|
||||
return AgentFeatureFlagsRead.from_flags(get_flags())
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentPermissionRule,
|
||||
|
|
@ -39,7 +40,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -133,15 +134,16 @@ def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead:
|
|||
|
||||
|
||||
async def _ensure_search_space_membership_admin(
|
||||
session: AsyncSession, user: User, search_space_id: int
|
||||
session: AsyncSession, auth: AuthContext, search_space_id: int
|
||||
) -> None:
|
||||
user = auth.user
|
||||
"""Curating agent rules == "settings" administration on the space."""
|
||||
space = await session.get(SearchSpace, search_space_id)
|
||||
if space is None:
|
||||
raise HTTPException(status_code=404, detail="Search space not found.")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_UPDATE.value,
|
||||
"You don't have permission to manage agent permission rules in this space.",
|
||||
|
|
@ -160,8 +162,9 @@ async def _ensure_search_space_membership_admin(
|
|||
async def list_rules(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> list[AgentPermissionRuleRead]:
|
||||
user = auth.user
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
|
|
@ -183,8 +186,9 @@ async def create_rule(
|
|||
search_space_id: int,
|
||||
payload: AgentPermissionRuleCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AgentPermissionRuleRead:
|
||||
user = auth.user
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
|
|
@ -232,8 +236,9 @@ async def update_rule(
|
|||
rule_id: int,
|
||||
payload: AgentPermissionRuleUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AgentPermissionRuleRead:
|
||||
user = auth.user
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
|
|
@ -266,8 +271,9 @@ async def delete_rule(
|
|||
search_space_id: int,
|
||||
rule_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> None:
|
||||
user = auth.user
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ here" affordance. To prevent accidental usage during the gap we return
|
|||
``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE``
|
||||
flag flips. Once enabled, the route runs:
|
||||
|
||||
1. Authentication via :func:`current_active_user`.
|
||||
1. Authentication via an interactive session context.
|
||||
2. Action lookup; 404 if the action does not belong to the thread.
|
||||
3. Authorization via :func:`app.services.revert_service.can_revert`.
|
||||
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
||||
|
|
@ -33,9 +33,9 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.services.revert_service import (
|
||||
|
|
@ -45,7 +45,7 @@ from app.services.revert_service import (
|
|||
load_thread,
|
||||
revert_action,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -57,8 +57,9 @@ async def revert_agent_action(
|
|||
thread_id: int,
|
||||
action_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> dict:
|
||||
user = auth.user
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack or not flags.enable_revert_route:
|
||||
raise HTTPException(
|
||||
|
|
@ -269,7 +270,7 @@ async def revert_agent_turn(
|
|||
thread_id: int,
|
||||
chat_turn_id: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> RevertTurnResponse:
|
||||
"""Revert every reversible action emitted during ``chat_turn_id``.
|
||||
|
||||
|
|
@ -281,6 +282,7 @@ async def revert_agent_turn(
|
|||
Partial success is intentional and returned with HTTP 200. Callers
|
||||
must inspect ``results[*].status`` to find rows that need attention.
|
||||
"""
|
||||
user = auth.user
|
||||
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack or not flags.enable_revert_route:
|
||||
|
|
|
|||
|
|
@ -10,16 +10,16 @@ from pydantic import ValidationError
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.airtable_connector import fetch_airtable_user_email
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -78,7 +78,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
|
||||
|
||||
@router.get("/auth/airtable/connector/add")
|
||||
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_airtable(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Airtable OAuth flow.
|
||||
|
||||
|
|
@ -89,6 +92,7 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, async_session_maker
|
||||
from app.schemas.auth import (
|
||||
LogoutAllResponse,
|
||||
|
|
@ -13,7 +14,7 @@ from app.schemas.auth import (
|
|||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from app.users import current_active_user, get_jwt_strategy
|
||||
from app.users import get_jwt_strategy, require_session_context
|
||||
from app.utils.refresh_tokens import (
|
||||
revoke_all_user_tokens,
|
||||
revoke_refresh_token,
|
||||
|
|
@ -83,11 +84,14 @@ async def revoke_token(request: LogoutRequest):
|
|||
|
||||
|
||||
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||
async def logout_all_devices(user: User = Depends(current_active_user)):
|
||||
async def logout_all_devices(
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Logout from all devices by revoking all refresh tokens for the user.
|
||||
Requires valid access token.
|
||||
"""
|
||||
user = auth.user
|
||||
await revoke_all_user_tokens(user.id)
|
||||
logger.info(f"User {user.id} logged out from all devices")
|
||||
return LogoutAllResponse()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ Routes for chat comments and mentions.
|
|||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
CommentBatchResponse,
|
||||
|
|
@ -25,7 +26,7 @@ from app.services.chat_comments_service import (
|
|||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -34,20 +35,20 @@ router = APIRouter()
|
|||
async def batch_list_comments(
|
||||
request: CommentBatchRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Batch-fetch comments for multiple messages in one request."""
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, user)
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, auth)
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""List all comments for a message with their replies."""
|
||||
return await get_comments_for_message(session, message_id, user)
|
||||
return await get_comments_for_message(session, message_id, auth)
|
||||
|
||||
|
||||
@router.post("/messages/{message_id}/comments", response_model=CommentResponse)
|
||||
|
|
@ -55,10 +56,10 @@ async def add_comment(
|
|||
message_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Create a top-level comment on an AI response."""
|
||||
return await create_comment(session, message_id, request.content, user)
|
||||
return await create_comment(session, message_id, request.content, auth)
|
||||
|
||||
|
||||
@router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse)
|
||||
|
|
@ -66,10 +67,10 @@ async def add_reply(
|
|||
comment_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Reply to an existing comment."""
|
||||
return await create_reply(session, comment_id, request.content, user)
|
||||
return await create_reply(session, comment_id, request.content, auth)
|
||||
|
||||
|
||||
@router.put("/comments/{comment_id}", response_model=CommentReplyResponse)
|
||||
|
|
@ -77,20 +78,20 @@ async def edit_comment(
|
|||
comment_id: int,
|
||||
request: CommentUpdateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Update a comment's content (author only)."""
|
||||
return await update_comment(session, comment_id, request.content, user)
|
||||
return await update_comment(session, comment_id, request.content, auth)
|
||||
|
||||
|
||||
@router.delete("/comments/{comment_id}")
|
||||
async def remove_comment(
|
||||
comment_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Delete a comment (author or user with COMMENTS_DELETE permission)."""
|
||||
return await delete_comment(session, comment_id, user)
|
||||
return await delete_comment(session, comment_id, auth)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -102,7 +103,7 @@ async def remove_comment(
|
|||
async def list_mentions(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""List mentions for the current user."""
|
||||
return await get_user_mentions(session, user, search_space_id)
|
||||
return await get_user_mentions(session, auth, search_space_id)
|
||||
|
|
|
|||
|
|
@ -16,15 +16,15 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.clickup_auth_credentials import ClickUpAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -61,7 +61,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/clickup/connector/add")
|
||||
async def connect_clickup(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_clickup(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate ClickUp OAuth flow.
|
||||
|
||||
|
|
@ -72,6 +75,7 @@ async def connect_clickup(space_id: int, user: User = Depends(current_active_use
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.services.composio_service import (
|
||||
|
|
@ -35,12 +35,13 @@ from app.services.composio_service import (
|
|||
TOOLKIT_TO_CONNECTOR_TYPE,
|
||||
ComposioService,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
count_connectors_of_type,
|
||||
get_base_name_for_type,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -68,13 +69,16 @@ def get_state_manager() -> OAuthStateManager:
|
|||
|
||||
|
||||
@router.get("/composio/toolkits")
|
||||
async def list_composio_toolkits(user: User = Depends(current_active_user)):
|
||||
async def list_composio_toolkits(
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
List available Composio toolkits.
|
||||
|
||||
Returns:
|
||||
JSON with list of available toolkits and their metadata.
|
||||
"""
|
||||
del auth
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
|
|
@ -98,7 +102,7 @@ async def initiate_composio_auth(
|
|||
toolkit_id: str = Query(
|
||||
..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')"
|
||||
),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Composio OAuth flow for a specific toolkit.
|
||||
|
|
@ -110,6 +114,7 @@ async def initiate_composio_auth(
|
|||
Returns:
|
||||
JSON with auth_url to redirect user to Composio authorization
|
||||
"""
|
||||
user = auth.user
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
|
|
@ -446,7 +451,7 @@ async def reauth_composio_connector(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -460,6 +465,7 @@ async def reauth_composio_connector(
|
|||
connector_id: ID of the existing Composio connector to re-authenticate
|
||||
return_url: Optional frontend path to redirect to after completion
|
||||
"""
|
||||
user = auth.user
|
||||
if not ComposioService.is_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Composio integration is not enabled."
|
||||
|
|
@ -644,7 +650,7 @@ async def list_composio_drive_folders(
|
|||
connector_id: int,
|
||||
parent_id: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
List folders AND files in user's Google Drive via Composio.
|
||||
|
|
@ -659,6 +665,7 @@ async def list_composio_drive_folders(
|
|||
)
|
||||
|
||||
connector = None
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -676,6 +683,8 @@ async def list_composio_drive_folders(
|
|||
detail="Composio Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
composio_connected_account_id = connector.config.get(
|
||||
"composio_connected_account_id"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,15 +15,15 @@ from pydantic import ValidationError
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
|
|
@ -77,7 +77,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/confluence/connector/add")
|
||||
async def connect_confluence(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_confluence(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Confluence OAuth flow.
|
||||
|
||||
|
|
@ -88,6 +91,7 @@ async def connect_confluence(space_id: int, user: User = Depends(current_active_
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -421,10 +425,11 @@ async def reauth_confluence(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
|
||||
user = auth.user
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -15,21 +15,22 @@ from pydantic import ValidationError
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -77,7 +78,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/discord/connector/add")
|
||||
async def connect_discord(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_discord(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Discord OAuth flow.
|
||||
|
||||
|
|
@ -88,6 +92,7 @@ async def connect_discord(space_id: int, user: User = Depends(current_active_use
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -610,7 +615,7 @@ def _compute_channel_permissions(
|
|||
async def get_discord_channels(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Get list of Discord text channels for a connector with permission info.
|
||||
|
|
@ -628,6 +633,7 @@ async def get_discord_channels(
|
|||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
user = auth.user
|
||||
try:
|
||||
# Get connector and verify ownership
|
||||
result = await session.execute(
|
||||
|
|
@ -646,6 +652,8 @@ async def get_discord_channels(
|
|||
detail="Discord connector not found or access denied",
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
# Get credentials and decrypt bot token
|
||||
credentials = DiscordAuthCredentialsBase.from_dict(connector.config)
|
||||
token_encryption = get_token_encryption()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.runtime.path_resolver import virtual_path_to_doc
|
||||
from app.db import (
|
||||
Chunk,
|
||||
|
|
@ -35,7 +36,7 @@ from app.schemas import (
|
|||
PaginatedResponse,
|
||||
)
|
||||
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
try:
|
||||
|
|
@ -60,8 +61,9 @@ MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per file
|
|||
async def create_documents(
|
||||
request: DocumentsCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create new documents.
|
||||
Requires DOCUMENTS_CREATE permission.
|
||||
|
|
@ -70,7 +72,7 @@ async def create_documents(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
|
|
@ -128,9 +130,10 @@ async def create_documents_file_upload(
|
|||
use_vision_llm: bool = Form(False),
|
||||
processing_mode: str = Form("basic"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Upload files as documents with real-time status tracking.
|
||||
|
||||
|
|
@ -159,7 +162,7 @@ async def create_documents_file_upload(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
|
|
@ -340,8 +343,9 @@ async def read_documents(
|
|||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List documents the user has access to, with optional filtering and pagination.
|
||||
Requires DOCUMENTS_READ permission for the search space(s).
|
||||
|
|
@ -369,7 +373,7 @@ async def read_documents(
|
|||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -519,8 +523,9 @@ async def search_documents(
|
|||
search_space_id: int | None = None,
|
||||
document_types: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Search documents by title substring, optionally filtered by search_space_id and document_types.
|
||||
Requires DOCUMENTS_READ permission for the search space(s).
|
||||
|
|
@ -549,7 +554,7 @@ async def search_documents(
|
|||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -677,8 +682,9 @@ async def search_document_titles(
|
|||
page: int = 0,
|
||||
page_size: int = 20,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Lightweight document title search optimized for mention picker (@mentions).
|
||||
|
||||
|
|
@ -703,7 +709,7 @@ async def search_document_titles(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -781,8 +787,9 @@ async def get_document_by_virtual_path(
|
|||
search_space_id: int,
|
||||
virtual_path: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Resolve a knowledge-base document by its agent-facing virtual path.
|
||||
|
||||
The agent renders every document under ``/documents/...`` with a
|
||||
|
|
@ -804,7 +811,7 @@ async def get_document_by_virtual_path(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -838,8 +845,9 @@ async def get_documents_status(
|
|||
search_space_id: int,
|
||||
document_ids: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Batch status endpoint for documents in a search space.
|
||||
|
||||
|
|
@ -849,7 +857,7 @@ async def get_documents_status(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -905,8 +913,9 @@ async def get_documents_status(
|
|||
async def get_document_type_counts(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get counts of documents by type for search spaces the user has access to.
|
||||
Requires DOCUMENTS_READ permission for the search space(s).
|
||||
|
|
@ -926,7 +935,7 @@ async def get_document_type_counts(
|
|||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -965,7 +974,7 @@ async def get_document_by_chunk_id(
|
|||
5, ge=0, description="Number of chunks before/after the cited chunk to include"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Retrieves a document based on a chunk ID, including a window of chunks around the cited one.
|
||||
|
|
@ -995,7 +1004,7 @@ async def get_document_by_chunk_id(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -1060,12 +1069,13 @@ async def get_document_by_chunk_id(
|
|||
async def get_watched_folders(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -1101,8 +1111,9 @@ async def get_document_chunks_paginated(
|
|||
None, ge=0, description="Direct offset; overrides page * page_size"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Paginated chunk loading for a document.
|
||||
Supports both page-based and offset-based access.
|
||||
|
|
@ -1120,7 +1131,7 @@ async def get_document_chunks_paginated(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -1162,8 +1173,9 @@ async def get_document_chunks_paginated(
|
|||
async def read_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific document by ID.
|
||||
Requires DOCUMENTS_READ permission for the search space.
|
||||
|
|
@ -1182,7 +1194,7 @@ async def read_document(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -1216,8 +1228,9 @@ async def update_document(
|
|||
document_id: int,
|
||||
document_update: DocumentUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a document.
|
||||
Requires DOCUMENTS_UPDATE permission for the search space.
|
||||
|
|
@ -1236,7 +1249,7 @@ async def update_document(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_document.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update documents in this search space",
|
||||
|
|
@ -1275,8 +1288,9 @@ async def update_document(
|
|||
async def delete_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a document.
|
||||
Requires DOCUMENTS_DELETE permission for the search space.
|
||||
|
|
@ -1311,7 +1325,7 @@ async def delete_document(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_DELETE.value,
|
||||
"You don't have permission to delete documents in this search space",
|
||||
|
|
@ -1355,8 +1369,9 @@ async def delete_document(
|
|||
async def list_document_versions(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List all versions for a document, ordered by version_number descending."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
|
|
@ -1396,8 +1411,9 @@ async def get_document_version(
|
|||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get full version content including source_markdown."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
|
|
@ -1434,8 +1450,9 @@ async def restore_document_version(
|
|||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Restore a previous version: snapshot current state, then overwrite document content."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
|
|
@ -1517,8 +1534,9 @@ class FolderSyncFinalizeRequest(PydanticBaseModel):
|
|||
async def folder_mtime_check(
|
||||
request: FolderMtimeCheckRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Pre-upload optimization: check which files need uploading based on mtime.
|
||||
|
||||
Returns the subset of relative paths where the file is new or has a
|
||||
|
|
@ -1528,7 +1546,7 @@ async def folder_mtime_check(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
|
|
@ -1587,8 +1605,9 @@ async def folder_upload(
|
|||
use_vision_llm: bool = Form(False),
|
||||
processing_mode: str = Form("basic"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Upload files from the desktop app for folder indexing.
|
||||
|
||||
Files are written to temp storage and dispatched to a Celery task.
|
||||
|
|
@ -1603,7 +1622,7 @@ async def folder_upload(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
|
|
@ -1733,8 +1752,9 @@ async def folder_upload(
|
|||
async def folder_unlink(
|
||||
request: FolderUnlinkRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Handle file deletion events from the desktop watcher.
|
||||
|
||||
For each relative path, find the matching document and delete it.
|
||||
|
|
@ -1746,7 +1766,7 @@ async def folder_unlink(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_DELETE.value,
|
||||
"You don't have permission to delete documents in this search space",
|
||||
|
|
@ -1787,8 +1807,9 @@ async def folder_unlink(
|
|||
async def folder_sync_finalize(
|
||||
request: FolderSyncFinalizeRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Finalize a full folder scan by deleting orphaned documents.
|
||||
|
||||
The client sends the complete list of relative paths currently in the
|
||||
|
|
@ -1803,7 +1824,7 @@ async def folder_sync_finalize(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_DELETE.value,
|
||||
"You don't have permission to delete documents in this search space",
|
||||
|
|
|
|||
|
|
@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.dropbox import DropboxClient, list_folder_contents
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
|
@ -66,8 +67,12 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/dropbox/connector/add")
|
||||
async def connect_dropbox(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_dropbox(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Initiate Dropbox OAuth flow."""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -109,10 +114,11 @@ async def reauth_dropbox(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Re-authenticate an existing Dropbox connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -405,10 +411,11 @@ async def list_dropbox_folders(
|
|||
connector_id: int,
|
||||
parent_path: str = "",
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""List folders and files in user's Dropbox."""
|
||||
connector = None
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -424,6 +431,8 @@ async def list_dropbox_folders(
|
|||
status_code=404, detail="Dropbox connector not found or access denied"
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
dropbox_client = DropboxClient(session, connector_id)
|
||||
items, error = await list_folder_contents(dropbox_client, path=parent_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from fastapi.responses import StreamingResponse
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
|
||||
from app.routes.reports_routes import (
|
||||
_FILE_EXTENSIONS,
|
||||
|
|
@ -31,7 +32,7 @@ from app.templates.export_helpers import (
|
|||
get_reference_docx_path,
|
||||
get_typst_template_path,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -47,8 +48,9 @@ async def get_editor_content(
|
|||
search_space_id: int,
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get document content for editing.
|
||||
|
||||
|
|
@ -60,7 +62,7 @@ async def get_editor_content(
|
|||
# Check RBAC permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -178,15 +180,16 @@ async def download_document_markdown(
|
|||
search_space_id: int,
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Download the full document content as a .md file.
|
||||
Reconstructs markdown from source_markdown or chunks.
|
||||
"""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
@ -244,8 +247,9 @@ async def save_document(
|
|||
document_id: int,
|
||||
data: dict[str, Any],
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Save document markdown and trigger reindexing.
|
||||
Called when user clicks 'Save & Exit'.
|
||||
|
|
@ -259,7 +263,7 @@ async def save_document(
|
|||
# Check RBAC permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update documents in this search space",
|
||||
|
|
@ -331,12 +335,13 @@ async def export_document(
|
|||
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Export a document in the requested format (reuses the report export pipeline)."""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
|||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.services.export_service import build_export_zip
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -24,12 +25,13 @@ async def export_knowledge_base(
|
|||
None, description="Export only this folder's subtree"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Export documents as a ZIP of markdown files preserving folder structure."""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to export documents in this search space",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import text
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Document, Folder, Permission, User, get_async_session
|
||||
from app.schemas import (
|
||||
BulkDocumentMove,
|
||||
|
|
@ -23,7 +24,7 @@ from app.services.folder_service import (
|
|||
get_subtree_max_depth,
|
||||
validate_folder_depth,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -33,13 +34,14 @@ router = APIRouter()
|
|||
async def create_folder(
|
||||
request: FolderCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Create a new folder. Requires DOCUMENTS_CREATE permission."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create folders in this search space",
|
||||
|
|
@ -91,13 +93,14 @@ async def create_folder(
|
|||
async def list_folders(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read folders in this search space",
|
||||
|
|
@ -122,8 +125,9 @@ async def list_folders(
|
|||
async def get_folder(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get a single folder. Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -132,7 +136,7 @@ async def get_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read folders in this search space",
|
||||
|
|
@ -152,8 +156,9 @@ async def get_folder(
|
|||
async def get_folder_breadcrumb(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -162,7 +167,7 @@ async def get_folder_breadcrumb(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read folders in this search space",
|
||||
|
|
@ -196,8 +201,9 @@ async def get_folder_breadcrumb(
|
|||
async def stop_watching_folder(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Clear the watched flag from a folder's metadata."""
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
|
|
@ -205,7 +211,7 @@ async def stop_watching_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update folders in this search space",
|
||||
|
|
@ -224,8 +230,9 @@ async def update_folder(
|
|||
folder_id: int,
|
||||
request: FolderUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Rename a folder. Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -234,7 +241,7 @@ async def update_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update folders in this search space",
|
||||
|
|
@ -264,8 +271,9 @@ async def move_folder(
|
|||
folder_id: int,
|
||||
request: FolderMove,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -274,7 +282,7 @@ async def move_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to move folders in this search space",
|
||||
|
|
@ -324,8 +332,9 @@ async def reorder_folder(
|
|||
folder_id: int,
|
||||
request: FolderReorder,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -334,7 +343,7 @@ async def reorder_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to reorder folders in this search space",
|
||||
|
|
@ -365,8 +374,9 @@ async def reorder_folder(
|
|||
async def delete_folder(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -375,7 +385,7 @@ async def delete_folder(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_DELETE.value,
|
||||
"You don't have permission to delete folders in this search space",
|
||||
|
|
@ -439,8 +449,9 @@ async def move_document(
|
|||
document_id: int,
|
||||
request: DocumentMove,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -452,7 +463,7 @@ async def move_document(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to move documents in this search space",
|
||||
|
|
@ -485,8 +496,9 @@ async def move_document(
|
|||
async def bulk_move_documents(
|
||||
request: BulkDocumentMove,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
if not request.document_ids:
|
||||
|
|
@ -504,7 +516,7 @@ async def bulk_move_documents(
|
|||
for ss_id in search_space_ids:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
ss_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to move documents in this search space",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from sqlalchemy import or_, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.responses import JSONResponse, RedirectResponse, Response
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ExternalChatAccount,
|
||||
|
|
@ -51,7 +52,7 @@ from app.observability.metrics import (
|
|||
record_gateway_inbox_write,
|
||||
record_gateway_webhook_parse_error,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
|
|
@ -250,14 +251,15 @@ def _telegram_message(payload: dict[str, Any]) -> dict[str, Any] | None:
|
|||
@router.get("/slack/install")
|
||||
async def install_slack_gateway(
|
||||
search_space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, str]:
|
||||
user = auth.user
|
||||
if not _slack_gateway_enabled():
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Slack gateway OAuth is not configured"
|
||||
)
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||
auth_params = {
|
||||
"client_id": config.GATEWAY_SLACK_CLIENT_ID,
|
||||
|
|
@ -409,14 +411,15 @@ async def slack_gateway_callback(
|
|||
@router.get("/discord/install")
|
||||
async def install_discord_gateway(
|
||||
search_space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, str]:
|
||||
user = auth.user
|
||||
if not _discord_gateway_enabled():
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Discord gateway OAuth is not configured"
|
||||
)
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||
auth_params = {
|
||||
"client_id": config.DISCORD_CLIENT_ID,
|
||||
|
|
@ -712,10 +715,11 @@ async def telegram_webhook(
|
|||
@router.post("/bindings/start", response_model=StartBindingResponse)
|
||||
async def start_binding(
|
||||
body: StartBindingRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> StartBindingResponse:
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, body.search_space_id)
|
||||
code = generate_pairing_code()
|
||||
if body.platform == ExternalChatPlatform.TELEGRAM:
|
||||
if not _telegram_gateway_enabled():
|
||||
|
|
@ -774,9 +778,10 @@ async def start_binding(
|
|||
|
||||
@router.get("/bindings")
|
||||
async def list_bindings(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(ExternalChatBinding, ExternalChatAccount)
|
||||
.join(
|
||||
|
|
@ -803,9 +808,10 @@ async def list_bindings(
|
|||
@router.get("/connections")
|
||||
async def list_connections(
|
||||
platform: ExternalChatPlatform | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
user = auth.user
|
||||
active_whatsapp_mode = _active_whatsapp_account_mode()
|
||||
if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is None:
|
||||
return []
|
||||
|
|
@ -946,9 +952,10 @@ async def list_connections(
|
|||
|
||||
@router.get("/platforms")
|
||||
async def list_platforms(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(ExternalChatAccount).where(
|
||||
(ExternalChatAccount.owner_user_id == user.id)
|
||||
|
|
@ -970,8 +977,9 @@ async def list_platforms(
|
|||
|
||||
@config_router.get("/config")
|
||||
async def get_gateway_config(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict[str, bool | str]:
|
||||
user = auth.user
|
||||
if not config.GATEWAY_ENABLED:
|
||||
return {
|
||||
"enabled": False,
|
||||
|
|
@ -993,9 +1001,10 @@ async def get_gateway_config(
|
|||
async def update_binding_search_space(
|
||||
binding_id: int,
|
||||
body: UpdateBindingSearchSpaceRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, bool]:
|
||||
user = auth.user
|
||||
binding = await session.get(ExternalChatBinding, binding_id)
|
||||
if binding is None or binding.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Binding not found")
|
||||
|
|
@ -1010,7 +1019,7 @@ async def update_binding_search_space(
|
|||
if account is None or _is_inactive_whatsapp_account(account):
|
||||
raise HTTPException(status_code=404, detail="Binding not found")
|
||||
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
await check_search_space_access(session, auth, body.search_space_id)
|
||||
if binding.search_space_id != body.search_space_id:
|
||||
binding.search_space_id = body.search_space_id
|
||||
binding.new_chat_thread_id = None
|
||||
|
|
@ -1023,9 +1032,10 @@ async def update_binding_search_space(
|
|||
async def update_gateway_account_search_space(
|
||||
account_id: int,
|
||||
body: UpdateAccountSearchSpaceRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, bool]:
|
||||
user = auth.user
|
||||
account = await session.get(ExternalChatAccount, account_id)
|
||||
if (
|
||||
account is None
|
||||
|
|
@ -1036,7 +1046,7 @@ async def update_gateway_account_search_space(
|
|||
):
|
||||
raise HTTPException(status_code=404, detail="Gateway account not found")
|
||||
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
await check_search_space_access(session, auth, body.search_space_id)
|
||||
account.owner_search_space_id = body.search_space_id
|
||||
account.updated_at = datetime.now(UTC)
|
||||
|
||||
|
|
@ -1061,9 +1071,10 @@ async def update_gateway_account_search_space(
|
|||
@router.delete("/bindings/{binding_id}")
|
||||
async def delete_binding(
|
||||
binding_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, bool]:
|
||||
user = auth.user
|
||||
binding = await session.get(ExternalChatBinding, binding_id)
|
||||
if binding is None or binding.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Binding not found")
|
||||
|
|
@ -1078,9 +1089,10 @@ async def delete_binding(
|
|||
@router.delete("/accounts/{account_id}")
|
||||
async def delete_gateway_account(
|
||||
account_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, bool]:
|
||||
user = auth.user
|
||||
account = await session.get(ExternalChatAccount, account_id)
|
||||
if (
|
||||
account is None
|
||||
|
|
@ -1114,9 +1126,10 @@ async def delete_gateway_account(
|
|||
@router.post("/bindings/{binding_id}/resume")
|
||||
async def resume_external_chat_binding(
|
||||
binding_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, bool]:
|
||||
user = auth.user
|
||||
binding = await session.get(ExternalChatBinding, binding_id)
|
||||
if binding is None or binding.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Binding not found")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ExternalChatAccount,
|
||||
|
|
@ -20,7 +21,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
router = APIRouter(prefix="/gateway/whatsapp/baileys", tags=["gateway"])
|
||||
|
|
@ -60,11 +61,12 @@ async def _get_user_whatsapp_account(
|
|||
@router.post("/pair")
|
||||
async def request_pairing_code(
|
||||
body: BaileysPairRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, Any]:
|
||||
user = auth.user
|
||||
_ensure_baileys_enabled()
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
await check_search_space_access(session, auth, body.search_space_id)
|
||||
adapter = WhatsAppBaileysAdapter()
|
||||
try:
|
||||
pairing = await adapter.request_pairing_code(phone_number=body.phone_number)
|
||||
|
|
@ -97,8 +99,9 @@ async def request_pairing_code(
|
|||
|
||||
@router.get("/health")
|
||||
async def bridge_health(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict[str, Any]:
|
||||
user = auth.user
|
||||
_ensure_baileys_enabled()
|
||||
adapter = WhatsAppBaileysAdapter()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -88,7 +88,11 @@ def get_google_flow():
|
|||
|
||||
|
||||
@router.get("/auth/google/calendar/connector/add")
|
||||
async def connect_calendar(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_calendar(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -127,10 +131,11 @@ async def reauth_calendar(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Google Calendar re-authentication for an existing connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.google_drive import (
|
||||
GoogleDriveClient,
|
||||
|
|
@ -33,10 +34,9 @@ from app.connectors.google_gmail_connector import fetch_google_user_email
|
|||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -46,6 +46,7 @@ from app.utils.oauth_security import (
|
|||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -110,7 +111,10 @@ def get_google_flow():
|
|||
|
||||
|
||||
@router.get("/auth/google/drive/connector/add")
|
||||
async def connect_drive(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_drive(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Google Drive OAuth flow.
|
||||
|
||||
|
|
@ -120,6 +124,7 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -165,7 +170,7 @@ async def reauth_drive(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -178,6 +183,7 @@ async def reauth_drive(
|
|||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -472,7 +478,7 @@ async def list_google_drive_folders(
|
|||
connector_id: int,
|
||||
parent_id: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
List folders AND files in user's Google Drive with hierarchical support.
|
||||
|
|
@ -492,6 +498,7 @@ async def list_google_drive_folders(
|
|||
]
|
||||
}
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
# Get connector and verify ownership
|
||||
result = await session.execute(
|
||||
|
|
@ -510,6 +517,8 @@ async def list_google_drive_folders(
|
|||
detail="Google Drive connector not found or access denied",
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
# Initialize Drive client (credentials will be loaded on first API call)
|
||||
drive_client = GoogleDriveClient(session, connector_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.google_gmail_connector import fetch_google_user_email
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -92,7 +92,10 @@ def get_google_flow():
|
|||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/add")
|
||||
async def connect_gmail(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_gmail(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Google Gmail OAuth flow.
|
||||
|
||||
|
|
@ -102,6 +105,7 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -145,10 +149,11 @@ async def reauth_gmail(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Gmail re-authentication for an existing connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
|
|
@ -46,7 +47,7 @@ from app.services.image_gen_router_service import (
|
|||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
||||
|
|
@ -245,8 +246,9 @@ async def _execute_image_generation(
|
|||
async def create_image_generation(
|
||||
data: ImageGenerationCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Create and execute an image generation request.
|
||||
|
||||
Premium configs are gated by the user's shared premium credit pool.
|
||||
|
|
@ -270,7 +272,7 @@ async def create_image_generation(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generations in this search space",
|
||||
|
|
@ -365,8 +367,9 @@ async def list_image_generations(
|
|||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List image generations."""
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
|
|
@ -377,7 +380,7 @@ async def list_image_generations(
|
|||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
|
|
@ -417,8 +420,9 @@ async def list_image_generations(
|
|||
async def get_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get a specific image generation by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -430,7 +434,7 @@ async def get_image_generation(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
|
|
@ -449,8 +453,9 @@ async def get_image_generation(
|
|||
async def delete_image_generation(
|
||||
image_gen_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Delete an image generation record."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -462,7 +467,7 @@ async def delete_image_generation(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generations in this search space",
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
INCENTIVE_TASKS_CONFIG,
|
||||
IncentiveTaskType,
|
||||
User,
|
||||
UserIncentiveTask,
|
||||
get_async_session,
|
||||
)
|
||||
|
|
@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import (
|
|||
IncentiveTasksResponse,
|
||||
TaskAlreadyCompletedResponse,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=IncentiveTasksResponse)
|
||||
async def get_incentive_tasks(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> IncentiveTasksResponse:
|
||||
"""
|
||||
Get all available incentive tasks with the user's completion status.
|
||||
"""
|
||||
user = auth.user
|
||||
# Get all completed tasks for this user
|
||||
result = await session.execute(
|
||||
select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id)
|
||||
|
|
@ -75,7 +76,7 @@ async def get_incentive_tasks(
|
|||
)
|
||||
async def complete_task(
|
||||
task_type: IncentiveTaskType,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> CompleteTaskResponse | TaskAlreadyCompletedResponse:
|
||||
"""
|
||||
|
|
@ -84,6 +85,7 @@ async def complete_task(
|
|||
Each task can only be completed once. If the task was already completed,
|
||||
returns the existing completion information without awarding additional credit.
|
||||
"""
|
||||
user = auth.user
|
||||
# Validate task type exists in config
|
||||
task_config = INCENTIVE_TASKS_CONFIG.get(task_type)
|
||||
if not task_config:
|
||||
|
|
|
|||
|
|
@ -16,15 +16,15 @@ from pydantic import ValidationError
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
|
|
@ -75,7 +75,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/jira/connector/add")
|
||||
async def connect_jira(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_jira(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Jira OAuth flow.
|
||||
|
||||
|
|
@ -86,6 +89,7 @@ async def connect_jira(space_id: int, user: User = Depends(current_active_user))
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -438,10 +442,11 @@ async def reauth_jira(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
|
||||
user = auth.user
|
||||
try:
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -17,16 +17,16 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import fetch_linear_organization_name
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -79,7 +79,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
|
||||
|
||||
@router.get("/auth/linear/connector/add")
|
||||
async def connect_linear(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_linear(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Linear OAuth flow.
|
||||
|
||||
|
|
@ -90,6 +93,7 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -134,10 +138,11 @@ async def reauth_linear(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Linear re-authentication for an existing connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import and_, desc
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Log,
|
||||
LogLevel,
|
||||
|
|
@ -16,7 +17,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import LogCreate, LogRead, LogUpdate
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -26,8 +27,9 @@ router = APIRouter()
|
|||
async def create_log(
|
||||
log: LogCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new log entry.
|
||||
Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated).
|
||||
|
|
@ -36,7 +38,7 @@ async def create_log(
|
|||
# Check if the user has access to the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
log.search_space_id,
|
||||
Permission.LOGS_READ.value,
|
||||
"You don't have permission to access logs in this search space",
|
||||
|
|
@ -67,8 +69,9 @@ async def read_logs(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get logs with optional filtering.
|
||||
Requires LOGS_READ permission for the search space(s).
|
||||
|
|
@ -81,7 +84,7 @@ async def read_logs(
|
|||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LOGS_READ.value,
|
||||
"You don't have permission to read logs in this search space",
|
||||
|
|
@ -136,8 +139,9 @@ async def read_logs(
|
|||
async def read_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific log by ID.
|
||||
Requires LOGS_READ permission for the search space.
|
||||
|
|
@ -152,7 +156,7 @@ async def read_log(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
log.search_space_id,
|
||||
Permission.LOGS_READ.value,
|
||||
"You don't have permission to read logs in this search space",
|
||||
|
|
@ -172,8 +176,9 @@ async def update_log(
|
|||
log_id: int,
|
||||
log_update: LogUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a log entry.
|
||||
Requires LOGS_READ permission (logs are typically updated by system).
|
||||
|
|
@ -188,7 +193,7 @@ async def update_log(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_log.search_space_id,
|
||||
Permission.LOGS_READ.value,
|
||||
"You don't have permission to access logs in this search space",
|
||||
|
|
@ -215,8 +220,9 @@ async def update_log(
|
|||
async def delete_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a log entry.
|
||||
Requires LOGS_DELETE permission for the search space.
|
||||
|
|
@ -231,7 +237,7 @@ async def delete_log(
|
|||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_log.search_space_id,
|
||||
Permission.LOGS_DELETE.value,
|
||||
"You don't have permission to delete logs in this search space",
|
||||
|
|
@ -254,8 +260,9 @@ async def get_logs_summary(
|
|||
search_space_id: int,
|
||||
hours: int = 24,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a summary of logs for a search space in the last X hours.
|
||||
Requires LOGS_READ permission for the search space.
|
||||
|
|
@ -264,7 +271,7 @@ async def get_logs_summary(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LOGS_READ.value,
|
||||
"You don't have permission to read logs in this search space",
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ class AddLumaConnectorRequest(BaseModel):
|
|||
@router.post("/connectors/luma/add")
|
||||
async def add_luma_connector(
|
||||
request: AddLumaConnectorRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -46,6 +46,7 @@ async def add_luma_connector(
|
|||
Raises:
|
||||
HTTPException: If connector already exists or validation fails
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
# Check if a Luma connector already exists for this search space and user
|
||||
result = await session.execute(
|
||||
|
|
@ -118,7 +119,7 @@ async def add_luma_connector(
|
|||
@router.delete("/connectors/luma")
|
||||
async def delete_luma_connector(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -135,6 +136,7 @@ async def delete_luma_connector(
|
|||
Raises:
|
||||
HTTPException: If connector doesn't exist
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -173,7 +175,7 @@ async def delete_luma_connector(
|
|||
@router.get("/connectors/luma/test")
|
||||
async def test_luma_connector(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -190,6 +192,7 @@ async def test_luma_connector(
|
|||
Raises:
|
||||
HTTPException: If connector doesn't exist or test fails
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
# Get the Luma connector for this search space and user
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -20,14 +20,14 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import generate_unique_connector_name
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
|
|
@ -164,8 +164,9 @@ def _frontend_redirect(
|
|||
async def connect_mcp_service(
|
||||
service: str,
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
|
|
@ -523,9 +524,10 @@ async def reauth_mcp_service(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.services.memory import (
|
||||
MemoryRead,
|
||||
MemoryScope,
|
||||
|
|
@ -15,7 +16,7 @@ from app.services.memory import (
|
|||
reset_memory,
|
||||
save_memory,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel):
|
|||
|
||||
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||
async def get_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
memory_md = await read_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
@ -40,9 +42,10 @@ async def get_user_memory(
|
|||
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||
async def update_user_memory(
|
||||
body: MemoryUpdate,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
@ -56,9 +59,10 @@ async def update_user_memory(
|
|||
|
||||
@router.post("/users/me/memory/reset", response_model=MemoryRead)
|
||||
async def reset_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
result = await reset_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import select, update
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Connection,
|
||||
|
|
@ -14,7 +15,6 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
|
|
@ -42,7 +42,7 @@ from app.services.model_connection_service import (
|
|||
verify_connection,
|
||||
)
|
||||
from app.services.provider_registry import REGISTRY
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -257,8 +257,8 @@ async def _default_unset_roles(
|
|||
|
||||
|
||||
@router.get("/model-providers", response_model=list[ModelProviderRead])
|
||||
async def list_model_providers(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def list_model_providers(auth: AuthContext = Depends(require_session_context)):
|
||||
del auth
|
||||
local_only = {"ollama_chat", "lm_studio"}
|
||||
return [
|
||||
ModelProviderRead(
|
||||
|
|
@ -298,14 +298,16 @@ async def _load_connection(session: AsyncSession, connection_id: int) -> Connect
|
|||
|
||||
async def _assert_connection_access(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
conn: Connection,
|
||||
permission: str = Permission.LLM_CONFIGS_CREATE.value,
|
||||
allow_spaceless_pat: bool = False,
|
||||
) -> None:
|
||||
user = auth.user
|
||||
if conn.search_space_id:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
conn.search_space_id,
|
||||
permission,
|
||||
"You don't have permission to manage model connections in this search space",
|
||||
|
|
@ -315,17 +317,22 @@ async def _assert_connection_access(
|
|||
raise HTTPException(
|
||||
status_code=403, detail="Connection does not belong to user"
|
||||
)
|
||||
if auth.is_gated and not allow_spaceless_pat:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Managing personal model connections requires an interactive session",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global-llm-config-status")
|
||||
async def global_llm_config_status(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def global_llm_config_status(auth: AuthContext = Depends(require_session_context)):
|
||||
del auth
|
||||
return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS}
|
||||
|
||||
|
||||
@router.get("/global-model-connections", response_model=list[ConnectionRead])
|
||||
async def list_global_connections(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def list_global_connections(auth: AuthContext = Depends(require_session_context)):
|
||||
del auth
|
||||
models_by_connection: dict[int, list[dict]] = {}
|
||||
for model in config.GLOBAL_MODELS:
|
||||
models_by_connection.setdefault(model["connection_id"], []).append(model)
|
||||
|
|
@ -339,13 +346,14 @@ async def list_global_connections(user: User = Depends(current_active_user)):
|
|||
async def list_connections(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
stmt = select(Connection).options(selectinload(Connection.models))
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model connections in this search space",
|
||||
|
|
@ -363,8 +371,9 @@ async def list_connections(
|
|||
async def create_connection(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.GLOBAL:
|
||||
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE:
|
||||
|
|
@ -372,11 +381,16 @@ async def create_connection(
|
|||
raise HTTPException(status_code=400, detail="search_space_id is required")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
elif auth.is_gated:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Managing personal model connections requires an interactive session",
|
||||
)
|
||||
payload = data.model_dump(exclude={"search_space_id", "models"})
|
||||
|
||||
conn = Connection(
|
||||
|
|
@ -411,16 +425,22 @@ async def create_connection(
|
|||
async def preview_connection_models(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
elif auth.is_gated:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Testing personal model connections requires an interactive session",
|
||||
)
|
||||
|
||||
draft = Connection(
|
||||
provider=data.provider,
|
||||
|
|
@ -445,16 +465,22 @@ async def preview_connection_models(
|
|||
async def test_preview_connection_model(
|
||||
data: ModelTestPreview,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
elif auth.is_gated:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Testing personal model connections requires an interactive session",
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
if not model_id:
|
||||
|
|
@ -491,11 +517,11 @@ async def update_connection(
|
|||
connection_id: int,
|
||||
data: ConnectionUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
|
|
@ -512,11 +538,11 @@ async def update_connection(
|
|||
async def delete_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
await session.delete(conn)
|
||||
|
|
@ -533,11 +559,11 @@ async def delete_connection(
|
|||
async def verify_model_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
result = await verify_connection(conn)
|
||||
return VerifyConnectionResponse(
|
||||
|
|
@ -551,11 +577,11 @@ async def verify_model_connection(
|
|||
async def discover_connection_models(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
try:
|
||||
discovered = await discover_models(conn)
|
||||
|
|
@ -595,11 +621,11 @@ async def add_manual_model(
|
|||
connection_id: int,
|
||||
data: ModelCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
|
|
@ -640,11 +666,11 @@ async def bulk_update_models(
|
|||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
|
||||
|
|
@ -674,7 +700,7 @@ async def update_model(
|
|||
model_id: int,
|
||||
data: ModelUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
|
|
@ -685,7 +711,7 @@ async def update_model(
|
|||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = model.connection.search_space_id
|
||||
update = data.model_dump(exclude_unset=True)
|
||||
|
|
@ -704,7 +730,7 @@ async def update_model(
|
|||
async def test_connection_model(
|
||||
model_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
|
|
@ -715,7 +741,7 @@ async def test_connection_model(
|
|||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
result = await test_model(model.connection, model)
|
||||
await session.commit()
|
||||
|
|
@ -730,11 +756,11 @@ async def test_connection_model(
|
|||
async def get_model_roles(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model roles in this search space",
|
||||
|
|
@ -756,11 +782,11 @@ async def update_model_roles(
|
|||
search_space_id: int,
|
||||
data: ModelRolesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update model roles in this search space",
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.db import User
|
||||
from app.auth.context import AuthContext
|
||||
from app.services.model_list_service import get_model_list
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,7 +27,7 @@ class ModelListItem(BaseModel):
|
|||
|
||||
@router.get("/models", response_model=list[ModelListItem])
|
||||
async def list_available_models(
|
||||
user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Return all available models grouped by provider.
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -75,7 +76,7 @@ from app.tasks.chat.streaming.flows import (
|
|||
stream_new_chat,
|
||||
stream_resume_chat,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.perf import get_perf_logger
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.user_message_multimodal import (
|
||||
|
|
@ -595,8 +596,9 @@ async def list_threads(
|
|||
search_space_id: int,
|
||||
limit: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all accessible threads for the current user in a search space.
|
||||
Returns threads and archived_threads for ThreadListPrimitive.
|
||||
|
|
@ -615,7 +617,7 @@ async def list_threads(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -702,8 +704,9 @@ async def search_threads(
|
|||
search_space_id: int,
|
||||
title: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Search accessible threads by title in a search space.
|
||||
|
||||
|
|
@ -721,7 +724,7 @@ async def search_threads(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -794,8 +797,9 @@ async def search_threads(
|
|||
async def create_thread(
|
||||
thread: NewChatThreadCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new chat thread.
|
||||
|
||||
|
|
@ -807,7 +811,7 @@ async def create_thread(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to create chats in this search space",
|
||||
|
|
@ -852,8 +856,9 @@ async def create_thread(
|
|||
async def get_thread_messages(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a thread with all its messages.
|
||||
This is used by ThreadHistoryAdapter.load() to restore conversation.
|
||||
|
|
@ -877,7 +882,7 @@ async def get_thread_messages(
|
|||
# Check permission to read chats in this search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -936,8 +941,9 @@ async def get_thread_messages(
|
|||
async def get_thread_full(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get full thread details with all messages.
|
||||
|
||||
|
|
@ -964,7 +970,7 @@ async def get_thread_full(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -1005,8 +1011,9 @@ async def update_thread(
|
|||
thread_id: int,
|
||||
thread_update: NewChatThreadUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a thread (title, archived status).
|
||||
Used for renaming and archiving threads.
|
||||
|
|
@ -1027,7 +1034,7 @@ async def update_thread(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1074,8 +1081,9 @@ async def update_thread(
|
|||
async def delete_thread(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a thread and all its messages.
|
||||
|
||||
|
|
@ -1095,7 +1103,7 @@ async def delete_thread(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_DELETE.value,
|
||||
"You don't have permission to delete chats in this search space",
|
||||
|
|
@ -1146,8 +1154,9 @@ async def update_thread_visibility(
|
|||
thread_id: int,
|
||||
visibility_update: NewChatThreadVisibilityUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update the visibility/sharing settings of a thread.
|
||||
|
||||
|
|
@ -1168,7 +1177,7 @@ async def update_thread_visibility(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1217,7 +1226,7 @@ async def update_thread_visibility(
|
|||
async def create_thread_snapshot(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Create a public snapshot of the thread.
|
||||
|
|
@ -1229,7 +1238,7 @@ async def create_thread_snapshot(
|
|||
return await create_snapshot(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
user=user,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1239,7 +1248,7 @@ async def create_thread_snapshot(
|
|||
async def list_thread_snapshots(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
List all public snapshots for this thread.
|
||||
|
|
@ -1252,7 +1261,7 @@ async def list_thread_snapshots(
|
|||
snapshots=await list_snapshots_for_thread(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
user=user,
|
||||
auth=auth,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1262,7 +1271,7 @@ async def delete_thread_snapshot(
|
|||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Delete a specific snapshot.
|
||||
|
|
@ -1275,7 +1284,7 @@ async def delete_thread_snapshot(
|
|||
session=session,
|
||||
thread_id=thread_id,
|
||||
snapshot_id=snapshot_id,
|
||||
user=user,
|
||||
auth=auth,
|
||||
)
|
||||
return {"message": "Snapshot deleted successfully"}
|
||||
|
||||
|
|
@ -1290,8 +1299,9 @@ async def append_message(
|
|||
thread_id: int,
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
.. deprecated:: 2026-05
|
||||
Replaced by the **SSE-based message ID handshake**. The streaming
|
||||
|
|
@ -1321,8 +1331,8 @@ async def append_message(
|
|||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Capture ``user.id`` as a primitive UUID up front. The
|
||||
# ``current_active_user`` dependency hands us a ``User`` ORM
|
||||
# Capture ``user.id`` as a primitive UUID up front. The auth
|
||||
# dependency hands us a ``User`` ORM
|
||||
# row bound to ``session``; if the outer ``except
|
||||
# IntegrityError`` block below ever fires (an unexpected
|
||||
# constraint like a foreign key violation — the common
|
||||
|
|
@ -1370,7 +1380,7 @@ async def append_message(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1597,8 +1607,9 @@ async def list_messages(
|
|||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List messages in a thread with pagination.
|
||||
|
||||
|
|
@ -1620,7 +1631,7 @@ async def list_messages(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -1662,7 +1673,7 @@ async def list_messages(
|
|||
|
||||
@router.get("/agent/tools", response_model=list[AgentToolInfo])
|
||||
async def list_agent_tools(
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""Return the list of built-in agent tools with their metadata.
|
||||
|
||||
|
|
@ -1691,8 +1702,9 @@ async def handle_new_chat(
|
|||
request: NewChatRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Stream chat responses from the deep agent.
|
||||
|
||||
|
|
@ -1717,7 +1729,7 @@ async def handle_new_chat(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to chat in this search space",
|
||||
|
|
@ -1795,6 +1807,7 @@ async def handle_new_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=image_urls,
|
||||
auth_context=auth,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -1821,8 +1834,9 @@ async def cancel_active_turn(
|
|||
thread_id: int,
|
||||
response: Response,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Signal cancellation for the currently running turn on ``thread_id``."""
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
|
|
@ -1833,7 +1847,7 @@ async def cancel_active_turn(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1873,8 +1887,9 @@ async def cancel_active_turn(
|
|||
async def get_turn_status(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
|
|
@ -1884,7 +1899,7 @@ async def get_turn_status(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view chats in this search space",
|
||||
|
|
@ -1911,8 +1926,9 @@ async def regenerate_response(
|
|||
request: RegenerateRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Regenerate the AI response for a chat thread.
|
||||
|
||||
|
|
@ -1947,7 +1963,7 @@ async def regenerate_response(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -2288,6 +2304,7 @@ async def regenerate_response(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=regenerate_image_urls or None,
|
||||
auth_context=auth,
|
||||
flow="regenerate",
|
||||
):
|
||||
yield chunk
|
||||
|
|
@ -2356,8 +2373,9 @@ async def resume_chat(
|
|||
request: ResumeRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
|
|
@ -2369,7 +2387,7 @@ async def resume_chat(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to chat in this search space",
|
||||
|
|
@ -2413,6 +2431,7 @@ async def resume_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
disabled_tools=request.disabled_tools,
|
||||
auth_context=auth,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Document, DocumentType, Permission, User, get_async_session
|
||||
from app.schemas import DocumentRead, PaginatedResponse
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -27,8 +28,9 @@ async def create_note(
|
|||
search_space_id: int,
|
||||
request: CreateNoteRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new note document.
|
||||
|
||||
|
|
@ -37,7 +39,7 @@ async def create_note(
|
|||
# Check RBAC permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create notes in this search space",
|
||||
|
|
@ -98,8 +100,9 @@ async def list_notes(
|
|||
page: int | None = None,
|
||||
page_size: int = 50,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all notes in a search space.
|
||||
|
||||
|
|
@ -108,7 +111,7 @@ async def list_notes(
|
|||
# Check RBAC permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read notes in this search space",
|
||||
|
|
@ -191,8 +194,9 @@ async def delete_note(
|
|||
search_space_id: int,
|
||||
note_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a note.
|
||||
|
||||
|
|
@ -201,7 +205,7 @@ async def delete_note(
|
|||
# Check RBAC permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_DELETE.value,
|
||||
"You don't have permission to delete notes in this search space",
|
||||
|
|
|
|||
|
|
@ -17,15 +17,15 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
|
|
@ -76,7 +76,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
|
||||
|
||||
@router.get("/auth/notion/connector/add")
|
||||
async def connect_notion(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_notion(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Notion OAuth flow.
|
||||
|
||||
|
|
@ -87,6 +90,7 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -131,10 +135,11 @@ async def reauth_notion(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Initiate Notion re-authentication for an existing connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
|
|||
|
|
@ -24,14 +24,14 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
|
|
@ -361,8 +361,9 @@ class OAuthConnectorRoute:
|
|||
@router.get(f"{oauth.auth_prefix}/connector/add")
|
||||
async def connect(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
|
|
@ -406,9 +407,10 @@ class OAuthConnectorRoute:
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
|
|
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
|
|||
upsert_note,
|
||||
)
|
||||
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
|
||||
from app.users import current_active_user
|
||||
from app.users import allow_any_principal, get_auth_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
|
|||
async def _resolve_vault_connector(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
vault_id: str,
|
||||
) -> SearchSourceConnector:
|
||||
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
|
||||
user = auth.user
|
||||
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
|
||||
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
|
||||
stmt = select(SearchSourceConnector).where(
|
||||
|
|
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
|
|||
|
||||
connector = (await session.execute(stmt)).scalars().first()
|
||||
if connector is not None:
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
return connector
|
||||
|
||||
raise HTTPException(
|
||||
|
|
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
|
|||
async def _ensure_search_space_access(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
) -> SearchSpace:
|
||||
"""Owner-only access to the search space (shared spaces are a follow-up)."""
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(
|
||||
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
|
||||
|
|
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
|
|||
"message": "You don't own that search space.",
|
||||
},
|
||||
)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
return space
|
||||
|
||||
|
||||
|
|
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
|
|||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def obsidian_health(
|
||||
user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(allow_any_principal),
|
||||
) -> HealthResponse:
|
||||
"""Return the API contract handshake; plugin caches it per onload."""
|
||||
return HealthResponse(
|
||||
|
|
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
|
|||
@router.post("/connect", response_model=ConnectResponse)
|
||||
async def obsidian_connect(
|
||||
payload: ConnectRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> ConnectResponse:
|
||||
"""Register a vault, refresh an existing one, or adopt another device's row.
|
||||
|
|
@ -321,8 +327,9 @@ async def obsidian_connect(
|
|||
the partial unique index can never produce two live rows for one vault.
|
||||
"""
|
||||
await _ensure_search_space_access(
|
||||
session, user=user, search_space_id=payload.search_space_id
|
||||
session, auth=auth, search_space_id=payload.search_space_id
|
||||
)
|
||||
user = auth.user
|
||||
|
||||
now_iso = datetime.now(UTC).isoformat()
|
||||
cfg = _build_config(payload, now_iso=now_iso)
|
||||
|
|
@ -445,13 +452,14 @@ async def obsidian_connect(
|
|||
@router.post("/sync", response_model=SyncAck)
|
||||
async def obsidian_sync(
|
||||
payload: SyncBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> SyncAck:
|
||||
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
user = auth.user
|
||||
notification = None
|
||||
try:
|
||||
notification = await _start_obsidian_sync_notification(
|
||||
|
|
@ -551,12 +559,12 @@ async def obsidian_sync(
|
|||
@router.post("/rename", response_model=RenameAck)
|
||||
async def obsidian_rename(
|
||||
payload: RenameBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> RenameAck:
|
||||
"""Apply a batch of vault rename events."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
|
||||
items: list[RenameAckItem] = []
|
||||
|
|
@ -618,12 +626,12 @@ async def obsidian_rename(
|
|||
@router.delete("/notes", response_model=DeleteAck)
|
||||
async def obsidian_delete_notes(
|
||||
payload: DeleteBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> DeleteAck:
|
||||
"""Soft-delete a batch of notes by vault-relative path."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
|
||||
deleted = 0
|
||||
|
|
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
|
|||
@router.get("/manifest", response_model=ManifestResponse)
|
||||
async def obsidian_manifest(
|
||||
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> ManifestResponse:
|
||||
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
|
||||
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
|
||||
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
|
||||
return await get_manifest(session, connector=connector, vault_id=vault_id)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=StatsResponse)
|
||||
async def obsidian_stats(
|
||||
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> StatsResponse:
|
||||
"""Active-note count + last sync time for the web tile.
|
||||
|
|
@ -681,7 +689,7 @@ async def obsidian_stats(
|
|||
``files_synced`` excludes tombstones so it matches ``/manifest``;
|
||||
``last_sync_at`` includes them so deletes advance the freshness signal.
|
||||
"""
|
||||
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
|
||||
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
|
||||
|
||||
is_active = Document.document_metadata["deleted_at"].as_string().is_(None)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,21 +21,22 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.onedrive import OneDriveClient, list_folder_contents
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
|
@ -73,8 +74,12 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/onedrive/connector/add")
|
||||
async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_onedrive(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Initiate OneDrive OAuth flow."""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -119,10 +124,11 @@ async def reauth_onedrive(
|
|||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Re-authenticate an existing OneDrive connector."""
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -412,10 +418,11 @@ async def list_onedrive_folders(
|
|||
connector_id: int,
|
||||
parent_id: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""List folders and files in user's OneDrive."""
|
||||
connector = None
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
@ -431,6 +438,8 @@ async def list_onedrive_folders(
|
|||
status_code=404, detail="OneDrive connector not found or access denied"
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
items, error = await list_folder_contents(onedrive_client, parent_id=parent_id)
|
||||
|
||||
|
|
|
|||
104
surfsense_backend/app/routes/personal_access_tokens_routes.py
Normal file
104
surfsense_backend/app/routes/personal_access_tokens_routes.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import PersonalAccessToken, get_async_session
|
||||
from app.schemas.pat import PATCreate, PATCreated, PATRead
|
||||
from app.users import require_session_context
|
||||
from app.utils.pat import generate_pat, hash_pat, token_prefix
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _expires_at(expires_in_days: int | None) -> datetime | None:
|
||||
max_expiry_days = config.PAT_MAX_EXPIRY_DAYS
|
||||
|
||||
if max_expiry_days is not None:
|
||||
if expires_in_days is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"This deployment requires PATs to have an expiry of "
|
||||
f"{max_expiry_days} days or less"
|
||||
),
|
||||
)
|
||||
if expires_in_days > max_expiry_days:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"PAT expiry cannot exceed {max_expiry_days} days",
|
||||
)
|
||||
|
||||
if expires_in_days is None:
|
||||
return None
|
||||
|
||||
return datetime.now(UTC) + timedelta(days=expires_in_days)
|
||||
|
||||
|
||||
@router.post("/pats", response_model=PATCreated)
|
||||
async def create_personal_access_token(
|
||||
body: PATCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> PATCreated:
|
||||
token = generate_pat()
|
||||
pat = PersonalAccessToken(
|
||||
user_id=auth.user.id,
|
||||
token_hash=hash_pat(token),
|
||||
token_prefix=token_prefix(token),
|
||||
label=body.label.strip(),
|
||||
expires_at=_expires_at(body.expires_in_days),
|
||||
)
|
||||
session.add(pat)
|
||||
await session.commit()
|
||||
await session.refresh(pat)
|
||||
|
||||
return PATCreated(
|
||||
id=pat.id,
|
||||
label=pat.label,
|
||||
token=token,
|
||||
prefix=pat.token_prefix,
|
||||
expires_at=pat.expires_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/pats", response_model=list[PATRead])
|
||||
async def list_personal_access_tokens(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> list[PATRead]:
|
||||
result = await session.execute(
|
||||
select(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.user_id == auth.user.id)
|
||||
.order_by(PersonalAccessToken.created_at.desc())
|
||||
)
|
||||
return [
|
||||
PATRead(
|
||||
id=pat.id,
|
||||
label=pat.label,
|
||||
prefix=pat.token_prefix,
|
||||
expires_at=pat.expires_at,
|
||||
last_used_at=pat.last_used_at,
|
||||
created_at=pat.created_at,
|
||||
)
|
||||
for pat in result.scalars().all()
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/pats/{pat_id}", status_code=204)
|
||||
async def delete_personal_access_token(
|
||||
pat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> None:
|
||||
await session.execute(
|
||||
delete(PersonalAccessToken).where(
|
||||
PersonalAccessToken.id == pat_id,
|
||||
PersonalAccessToken.user_id == auth.user.id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
|
@ -3,14 +3,15 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import Prompt, SearchSpaceMembership, User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Prompt, SearchSpaceMembership, get_async_session
|
||||
from app.schemas.prompts import (
|
||||
PromptCreate,
|
||||
PromptRead,
|
||||
PromptUpdate,
|
||||
PublicPromptRead,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(tags=["Prompts"])
|
||||
|
||||
|
|
@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"])
|
|||
async def list_prompts(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
query = select(Prompt).where(Prompt.user_id == user.id)
|
||||
if search_space_id is not None:
|
||||
query = query.where(Prompt.search_space_id == search_space_id)
|
||||
|
|
@ -33,8 +35,9 @@ async def list_prompts(
|
|||
async def create_prompt(
|
||||
body: PromptCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
if body.search_space_id is not None:
|
||||
membership = await session.execute(
|
||||
select(SearchSpaceMembership).where(
|
||||
|
|
@ -67,8 +70,9 @@ async def update_prompt(
|
|||
prompt_id: int,
|
||||
body: PromptUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
@ -99,8 +103,9 @@ async def update_prompt(
|
|||
async def delete_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
@ -119,8 +124,9 @@ async def delete_prompt(
|
|||
@router.get("/prompts/public", response_model=list[PublicPromptRead])
|
||||
async def list_public_prompts(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt)
|
||||
.options(selectinload(Prompt.user))
|
||||
|
|
@ -141,8 +147,9 @@ async def list_public_prompts(
|
|||
async def copy_public_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.schemas.new_chat import (
|
||||
CloneResponse,
|
||||
PublicChatResponse,
|
||||
|
|
@ -23,7 +24,7 @@ from app.services.public_chat_service import (
|
|||
get_snapshot_report,
|
||||
get_snapshot_video_presentation,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
|
@ -46,8 +47,9 @@ async def read_public_chat(
|
|||
async def clone_public_chat(
|
||||
share_token: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Clone a public chat snapshot to the user's account.
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -43,7 +44,7 @@ from app.schemas import (
|
|||
RoleUpdate,
|
||||
UserSearchSpaceAccess,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import (
|
||||
check_permission,
|
||||
check_search_space_access,
|
||||
|
|
@ -107,6 +108,8 @@ PERMISSION_DESCRIPTIONS = {
|
|||
"settings:view": "View search space settings",
|
||||
"settings:update": "Modify search space settings",
|
||||
"settings:delete": "Delete the entire search space",
|
||||
# API access
|
||||
"api_access:manage": "Enable or disable programmatic API access for a search space",
|
||||
# Automations
|
||||
"automations:create": "Create automations from chat or JSON",
|
||||
"automations:read": "View automations, their triggers, and run history",
|
||||
|
|
@ -120,8 +123,9 @@ PERMISSION_DESCRIPTIONS = {
|
|||
|
||||
@router.get("/permissions", response_model=PermissionsListResponse)
|
||||
async def list_all_permissions(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all available permissions that can be assigned to roles.
|
||||
"""
|
||||
|
|
@ -156,8 +160,9 @@ async def create_role(
|
|||
search_space_id: int,
|
||||
role_data: RoleCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new custom role in a search space.
|
||||
Requires ROLES_CREATE permission.
|
||||
|
|
@ -165,7 +170,7 @@ async def create_role(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.ROLES_CREATE.value,
|
||||
"You don't have permission to create roles",
|
||||
|
|
@ -237,8 +242,9 @@ async def create_role(
|
|||
async def list_roles(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all roles in a search space.
|
||||
Requires ROLES_READ permission.
|
||||
|
|
@ -246,7 +252,7 @@ async def list_roles(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.ROLES_READ.value,
|
||||
"You don't have permission to view roles",
|
||||
|
|
@ -275,8 +281,9 @@ async def get_role(
|
|||
search_space_id: int,
|
||||
role_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific role by ID.
|
||||
Requires ROLES_READ permission.
|
||||
|
|
@ -284,7 +291,7 @@ async def get_role(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.ROLES_READ.value,
|
||||
"You don't have permission to view roles",
|
||||
|
|
@ -320,8 +327,9 @@ async def update_role(
|
|||
role_id: int,
|
||||
role_update: RoleUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a role.
|
||||
Requires ROLES_UPDATE permission.
|
||||
|
|
@ -330,7 +338,7 @@ async def update_role(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.ROLES_UPDATE.value,
|
||||
"You don't have permission to update roles",
|
||||
|
|
@ -417,8 +425,9 @@ async def delete_role(
|
|||
search_space_id: int,
|
||||
role_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a custom role.
|
||||
Requires ROLES_DELETE permission.
|
||||
|
|
@ -427,7 +436,7 @@ async def delete_role(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.ROLES_DELETE.value,
|
||||
"You don't have permission to delete roles",
|
||||
|
|
@ -474,8 +483,9 @@ async def delete_role(
|
|||
async def list_members(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all members of a search space.
|
||||
Requires MEMBERS_VIEW permission.
|
||||
|
|
@ -483,7 +493,7 @@ async def list_members(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_VIEW.value,
|
||||
"You don't have permission to view members",
|
||||
|
|
@ -539,8 +549,9 @@ async def update_member_role(
|
|||
membership_id: int,
|
||||
membership_update: MembershipUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a member's role.
|
||||
Requires MEMBERS_MANAGE_ROLES permission.
|
||||
|
|
@ -549,7 +560,7 @@ async def update_member_role(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_MANAGE_ROLES.value,
|
||||
"You don't have permission to manage member roles",
|
||||
|
|
@ -629,8 +640,9 @@ async def update_member_role(
|
|||
async def leave_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Leave a search space (remove own membership).
|
||||
Owners cannot leave their search space.
|
||||
|
|
@ -675,8 +687,9 @@ async def remove_member(
|
|||
search_space_id: int,
|
||||
membership_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Remove a member from a search space.
|
||||
Requires MEMBERS_REMOVE permission.
|
||||
|
|
@ -685,7 +698,7 @@ async def remove_member(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_REMOVE.value,
|
||||
"You don't have permission to remove members",
|
||||
|
|
@ -733,8 +746,9 @@ async def create_invite(
|
|||
search_space_id: int,
|
||||
invite_data: InviteCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new invite link for a search space.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -742,7 +756,7 @@ async def create_invite(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_INVITE.value,
|
||||
"You don't have permission to create invites",
|
||||
|
|
@ -798,8 +812,9 @@ async def create_invite(
|
|||
async def list_invites(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all invites for a search space.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -807,7 +822,7 @@ async def list_invites(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_INVITE.value,
|
||||
"You don't have permission to view invites",
|
||||
|
|
@ -837,8 +852,9 @@ async def update_invite(
|
|||
invite_id: int,
|
||||
invite_update: InviteUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update an invite.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -846,7 +862,7 @@ async def update_invite(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_INVITE.value,
|
||||
"You don't have permission to update invites",
|
||||
|
|
@ -903,8 +919,9 @@ async def revoke_invite(
|
|||
search_space_id: int,
|
||||
invite_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Revoke (delete) an invite.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -912,7 +929,7 @@ async def revoke_invite(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.MEMBERS_INVITE.value,
|
||||
"You don't have permission to revoke invites",
|
||||
|
|
@ -1022,8 +1039,9 @@ async def get_invite_info(
|
|||
async def accept_invite(
|
||||
request: InviteAcceptRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Accept an invite and join a search space.
|
||||
"""
|
||||
|
|
@ -1120,13 +1138,14 @@ async def accept_invite(
|
|||
async def get_my_access(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get the current user's access info for a search space.
|
||||
"""
|
||||
try:
|
||||
membership = await check_search_space_access(session, user, search_space_id)
|
||||
membership = await check_search_space_access(session, auth, search_space_id)
|
||||
|
||||
# Get search space name
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Report,
|
||||
SearchSpace,
|
||||
|
|
@ -42,7 +43,7 @@ from app.templates.export_helpers import (
|
|||
get_reference_docx_path,
|
||||
get_typst_template_path,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -158,8 +159,9 @@ def _normalize_latex_delimiters(text: str) -> str:
|
|||
async def _get_report_with_access(
|
||||
report_id: int,
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> Report:
|
||||
user = auth.user
|
||||
"""Fetch a report and verify the user belongs to its search space.
|
||||
|
||||
Raises HTTPException(404) if not found, HTTPException(403) if no access.
|
||||
|
|
@ -172,7 +174,7 @@ async def _get_report_with_access(
|
|||
|
||||
# Lightweight membership check - no granular RBAC, just "is the user a
|
||||
# member of the search space this report belongs to?"
|
||||
await check_search_space_access(session, user, report.search_space_id)
|
||||
await check_search_space_access(session, auth, report.search_space_id)
|
||||
|
||||
return report
|
||||
|
||||
|
|
@ -206,8 +208,9 @@ async def read_reports(
|
|||
limit: int = Query(default=100, ge=1, le=MAX_REPORT_LIST_LIMIT),
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List reports the user has access to.
|
||||
Filters by search space membership.
|
||||
|
|
@ -215,7 +218,7 @@ async def read_reports(
|
|||
try:
|
||||
if search_space_id is not None:
|
||||
# Verify the caller is a member of the requested search space
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
select(Report)
|
||||
|
|
@ -247,8 +250,9 @@ async def read_reports(
|
|||
async def read_report(
|
||||
report_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific report by ID (metadata only, no content).
|
||||
"""
|
||||
|
|
@ -266,8 +270,9 @@ async def read_report(
|
|||
async def read_report_content(
|
||||
report_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get full Markdown content of a report, including version siblings.
|
||||
"""
|
||||
|
|
@ -298,8 +303,9 @@ async def update_report_content(
|
|||
report_id: int,
|
||||
body: ReportContentUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update the Markdown content of a report.
|
||||
|
||||
|
|
@ -339,8 +345,9 @@ async def update_report_content(
|
|||
async def preview_report_pdf(
|
||||
report_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Return a compiled PDF preview for Typst-based reports (resumes).
|
||||
|
||||
|
|
@ -394,8 +401,9 @@ async def export_report(
|
|||
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Export a report in the requested format.
|
||||
"""
|
||||
|
|
@ -568,8 +576,9 @@ async def export_report(
|
|||
async def delete_report(
|
||||
report_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a report.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ from fastapi.responses import Response
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import NewChatThread, Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -47,8 +48,9 @@ async def download_sandbox_file(
|
|||
thread_id: int,
|
||||
path: str = Query(..., description="Absolute path of the file inside the sandbox"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Download a file from the Daytona sandbox associated with a chat thread."""
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import (
|
||||
|
|
@ -68,7 +70,7 @@ async def download_sandbox_file(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to access files in this thread",
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from app.db import (
|
||||
|
|
@ -56,7 +57,7 @@ from app.schemas import (
|
|||
SearchSourceConnectorUpdate,
|
||||
)
|
||||
from app.services.composio_service import ComposioService, get_composio_service
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
# NOTE: connector indexer functions are imported lazily inside each
|
||||
# ``run_*_indexing`` helper to break a circular import cycle:
|
||||
|
|
@ -143,8 +144,9 @@ class GitHubPATRequest(BaseModel):
|
|||
@router.post("/github/repositories", response_model=list[dict[str, Any]])
|
||||
async def list_github_repositories(
|
||||
pat_request: GitHubPATRequest,
|
||||
user: User = Depends(current_active_user), # Ensure the user is logged in
|
||||
auth: AuthContext = Depends(get_auth_context), # Ensure the user is logged in
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Fetches a list of repositories accessible by the provided GitHub PAT.
|
||||
The PAT is used for this request only and is not stored.
|
||||
|
|
@ -173,8 +175,9 @@ async def create_search_source_connector(
|
|||
..., description="ID of the search space to associate the connector with"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new search source connector.
|
||||
Requires CONNECTORS_CREATE permission.
|
||||
|
|
@ -186,7 +189,7 @@ async def create_search_source_connector(
|
|||
# Check if user has permission to create connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
"You don't have permission to create connectors in this search space",
|
||||
|
|
@ -281,8 +284,9 @@ async def read_search_source_connectors(
|
|||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all search source connectors for a search space.
|
||||
Requires CONNECTORS_READ permission.
|
||||
|
|
@ -297,7 +301,7 @@ async def read_search_source_connectors(
|
|||
# Check if user has permission to read connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view connectors in this search space",
|
||||
|
|
@ -324,8 +328,9 @@ async def read_search_source_connectors(
|
|||
async def read_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific search source connector by ID.
|
||||
Requires CONNECTORS_READ permission.
|
||||
|
|
@ -345,7 +350,7 @@ async def read_search_source_connector(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view this connector",
|
||||
|
|
@ -367,8 +372,9 @@ async def update_search_source_connector(
|
|||
connector_id: int,
|
||||
connector_update: SearchSourceConnectorUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a search source connector.
|
||||
Requires CONNECTORS_UPDATE permission.
|
||||
|
|
@ -386,7 +392,7 @@ async def update_search_source_connector(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_connector.search_space_id,
|
||||
Permission.CONNECTORS_UPDATE.value,
|
||||
"You don't have permission to update this connector",
|
||||
|
|
@ -557,8 +563,9 @@ async def update_search_source_connector(
|
|||
async def delete_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a search source connector and all its associated documents.
|
||||
|
||||
|
|
@ -588,7 +595,7 @@ async def delete_search_source_connector(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_connector.search_space_id,
|
||||
Permission.CONNECTORS_DELETE.value,
|
||||
"You don't have permission to delete this connector",
|
||||
|
|
@ -725,8 +732,9 @@ async def index_connector_content(
|
|||
description="[Google Drive only] Structured request with folders and files to index",
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Index content from a KB connector to a search space.
|
||||
|
||||
|
|
@ -760,7 +768,7 @@ async def index_connector_content(
|
|||
# the read/update/delete handlers — not the client-supplied query param.
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_UPDATE.value,
|
||||
"You don't have permission to index content in this search space",
|
||||
|
|
@ -2645,8 +2653,9 @@ async def create_mcp_connector(
|
|||
connector_data: MCPConnectorCreate,
|
||||
search_space_id: int = Query(..., description="Search space ID"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new MCP (Model Context Protocol) connector.
|
||||
|
||||
|
|
@ -2669,7 +2678,7 @@ async def create_mcp_connector(
|
|||
# Check user has permission to create connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
"You don't have permission to create connectors in this search space",
|
||||
|
|
@ -2724,8 +2733,9 @@ async def create_mcp_connector(
|
|||
async def list_mcp_connectors(
|
||||
search_space_id: int = Query(..., description="Search space ID"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all MCP connectors for a search space.
|
||||
|
||||
|
|
@ -2741,7 +2751,7 @@ async def list_mcp_connectors(
|
|||
# Check user has permission to read connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view connectors in this search space",
|
||||
|
|
@ -2775,8 +2785,9 @@ async def list_mcp_connectors(
|
|||
async def get_mcp_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific MCP connector by ID.
|
||||
|
||||
|
|
@ -2805,7 +2816,7 @@ async def get_mcp_connector(
|
|||
# Check user has permission to read connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view this connector",
|
||||
|
|
@ -2828,8 +2839,9 @@ async def update_mcp_connector(
|
|||
connector_id: int,
|
||||
connector_update: MCPConnectorUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update an MCP connector.
|
||||
|
||||
|
|
@ -2859,7 +2871,7 @@ async def update_mcp_connector(
|
|||
# Check user has permission to update connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_UPDATE.value,
|
||||
"You don't have permission to update this connector",
|
||||
|
|
@ -2904,8 +2916,9 @@ async def update_mcp_connector(
|
|||
async def delete_mcp_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete an MCP connector.
|
||||
|
||||
|
|
@ -2931,7 +2944,7 @@ async def delete_mcp_connector(
|
|||
# Check user has permission to delete connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_DELETE.value,
|
||||
"You don't have permission to delete this connector",
|
||||
|
|
@ -2962,8 +2975,9 @@ async def delete_mcp_connector(
|
|||
@router.post("/connectors/mcp/test")
|
||||
async def test_mcp_server_connection(
|
||||
server_config: dict = Body(...),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Test connection to an MCP server and fetch available tools.
|
||||
|
||||
|
|
@ -3042,8 +3056,9 @@ DRIVE_CONNECTOR_TYPES = {
|
|||
async def get_drive_picker_token(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Return an OAuth access token + client ID for the Google Picker API."""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)
|
||||
|
|
@ -3054,7 +3069,7 @@ async def get_drive_picker_token(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to access this connector",
|
||||
|
|
@ -3164,8 +3179,9 @@ async def trust_mcp_tool(
|
|||
connector_id: int,
|
||||
body: MCPTrustToolRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
||||
|
||||
Once trusted, the tool executes without HITL approval on subsequent
|
||||
|
|
@ -3209,8 +3225,9 @@ async def untrust_mcp_tool(
|
|||
connector_id: int,
|
||||
body: MCPTrustToolRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Remove a tool from the MCP connector's trusted list.
|
||||
|
||||
The tool will require HITL approval again on subsequent calls.
|
||||
|
|
|
|||
|
|
@ -5,22 +5,23 @@ from sqlalchemy import func
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
User,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
)
|
||||
from app.schemas import (
|
||||
SearchSpaceApiAccessUpdate,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
SearchSpaceUpdate,
|
||||
SearchSpaceWithStats,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import allow_any_principal, get_auth_context, require_session_context
|
||||
from app.utils.rbac import check_permission, check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -74,8 +75,9 @@ async def create_default_roles_and_membership(
|
|||
async def create_search_space(
|
||||
search_space: SearchSpaceCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
try:
|
||||
search_space_data = search_space.model_dump()
|
||||
|
||||
|
|
@ -108,8 +110,9 @@ async def read_search_spaces(
|
|||
limit: int = 200,
|
||||
owned_only: bool = False,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(allow_any_principal),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get all search spaces the user has access to, with member count and ownership info.
|
||||
|
||||
|
|
@ -123,11 +126,17 @@ async def read_search_spaces(
|
|||
# Exclude spaces that are pending background deletion
|
||||
not_deleting = ~SearchSpace.name.startswith("[DELETING] ")
|
||||
|
||||
api_access_filter = (
|
||||
SearchSpace.api_access_enabled == True # noqa: E712
|
||||
if auth.is_gated
|
||||
else True
|
||||
)
|
||||
|
||||
if owned_only:
|
||||
# Return only search spaces where user is the original creator (user_id)
|
||||
result = await session.execute(
|
||||
select(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id, not_deleting)
|
||||
.filter(SearchSpace.user_id == user.id, not_deleting, api_access_filter)
|
||||
.order_by(SearchSpace.id.asc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
|
|
@ -137,7 +146,11 @@ async def read_search_spaces(
|
|||
result = await session.execute(
|
||||
select(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id, not_deleting)
|
||||
.filter(
|
||||
SearchSpaceMembership.user_id == user.id,
|
||||
not_deleting,
|
||||
api_access_filter,
|
||||
)
|
||||
.order_by(SearchSpace.id.asc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
|
|
@ -174,6 +187,7 @@ async def read_search_spaces(
|
|||
created_at=space.created_at,
|
||||
user_id=space.user_id,
|
||||
citations_enabled=space.citations_enabled,
|
||||
api_access_enabled=space.api_access_enabled,
|
||||
qna_custom_instructions=space.qna_custom_instructions,
|
||||
ai_file_sort_enabled=space.ai_file_sort_enabled,
|
||||
member_count=member_count,
|
||||
|
|
@ -192,7 +206,7 @@ async def read_search_spaces(
|
|||
async def read_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Get a specific search space by ID.
|
||||
|
|
@ -200,7 +214,7 @@ async def read_search_space(
|
|||
"""
|
||||
try:
|
||||
# Check if user has access (is a member)
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
|
|
@ -225,7 +239,7 @@ async def update_search_space(
|
|||
search_space_id: int,
|
||||
search_space_update: SearchSpaceUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Update a search space.
|
||||
|
|
@ -235,7 +249,7 @@ async def update_search_space(
|
|||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_UPDATE.value,
|
||||
"You don't have permission to update this search space",
|
||||
|
|
@ -265,17 +279,65 @@ async def update_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead)
|
||||
async def update_search_space_api_access(
|
||||
search_space_id: int,
|
||||
body: SearchSpaceApiAccessUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Toggle programmatic API/PAT access for a search space.
|
||||
Requires API_ACCESS_MANAGE permission.
|
||||
"""
|
||||
try:
|
||||
if not auth.is_session:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="This action requires an interactive session",
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.API_ACCESS_MANAGE.value,
|
||||
"You don't have permission to manage API access for this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
db_search_space = result.scalars().first()
|
||||
|
||||
if not db_search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
db_search_space.api_access_enabled = body.api_access_enabled
|
||||
await session.commit()
|
||||
await session.refresh(db_search_space)
|
||||
return db_search_space
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update API access: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/searchspaces/{search_space_id}/ai-sort")
|
||||
async def trigger_ai_sort(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Trigger a full AI file sort for all documents in the search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_UPDATE.value,
|
||||
"You don't have permission to trigger AI sort on this search space",
|
||||
|
|
@ -305,7 +367,7 @@ async def trigger_ai_sort(
|
|||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
Delete a search space.
|
||||
|
|
@ -318,7 +380,7 @@ async def delete_search_space(
|
|||
# Check permission - only those with SETTINGS_DELETE can delete
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_DELETE.value,
|
||||
"You don't have permission to delete this search space",
|
||||
|
|
@ -374,7 +436,7 @@ async def delete_search_space(
|
|||
async def list_search_space_snapshots(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""
|
||||
List all public chat snapshots for a search space.
|
||||
|
|
@ -387,6 +449,6 @@ async def list_search_space_snapshots(
|
|||
snapshots = await list_snapshots_for_search_space(
|
||||
session=session,
|
||||
search_space_id=search_space_id,
|
||||
user=user,
|
||||
auth=auth,
|
||||
)
|
||||
return PublicChatSnapshotsBySpaceResponse(snapshots=snapshots)
|
||||
|
|
|
|||
|
|
@ -17,21 +17,22 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context, require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -78,7 +79,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/slack/connector/add")
|
||||
async def connect_slack(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_slack(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Slack OAuth flow.
|
||||
|
||||
|
|
@ -89,6 +93,7 @@ async def connect_slack(space_id: int, user: User = Depends(current_active_user)
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
@ -525,7 +530,7 @@ async def refresh_slack_token(
|
|||
async def get_slack_channels(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get list of Slack channels with bot membership status.
|
||||
|
|
@ -541,6 +546,7 @@ async def get_slack_channels(
|
|||
Returns:
|
||||
List of channels with id, name, is_private, and is_member fields
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
# Get the connector and verify ownership
|
||||
result = await session.execute(
|
||||
|
|
@ -559,6 +565,8 @@ async def get_slack_channels(
|
|||
detail="Slack connector not found or access denied",
|
||||
)
|
||||
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
|
||||
# Get credentials and decrypt bot token
|
||||
credentials = SlackAuthCredentialsBase.from_dict(connector.config)
|
||||
token_encryption = get_token_encryption()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from stripe import SignatureVerificationError, StripeClient, StripeError
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
CreditPurchase,
|
||||
|
|
@ -39,7 +40,7 @@ from app.schemas.stripe import (
|
|||
StripeWebhookResponse,
|
||||
UpdateAutoReloadSettingsRequest,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -456,7 +457,7 @@ async def _reconcile_auto_reload_payment_intent(
|
|||
)
|
||||
async def create_credit_checkout_session(
|
||||
body: CreateCreditCheckoutSessionRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> CreateCreditCheckoutSessionResponse:
|
||||
"""Create a Stripe Checkout Session for buying credit packs.
|
||||
|
|
@ -466,6 +467,7 @@ async def create_credit_checkout_session(
|
|||
cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page
|
||||
(ETL), so $1 of credit always buys $1 worth of usage at cost.
|
||||
"""
|
||||
user = auth.user
|
||||
_ensure_credit_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_required_credit_price_id()
|
||||
|
|
@ -644,7 +646,7 @@ async def stripe_webhook(
|
|||
@router.get("/finalize-checkout", response_model=FinalizeCheckoutResponse)
|
||||
async def finalize_checkout(
|
||||
session_id: str,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> FinalizeCheckoutResponse:
|
||||
"""Synchronously fulfil a credit checkout session from the success page.
|
||||
|
|
@ -659,6 +661,7 @@ async def finalize_checkout(
|
|||
Authorization: the session's ``client_reference_id`` must match the
|
||||
authenticated user's id.
|
||||
"""
|
||||
user = auth.user
|
||||
stripe_client = get_stripe_client()
|
||||
|
||||
try:
|
||||
|
|
@ -718,13 +721,14 @@ async def finalize_checkout(
|
|||
|
||||
@router.get("/credit-status", response_model=CreditStripeStatusResponse)
|
||||
async def get_credit_status(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> CreditStripeStatusResponse:
|
||||
"""Return credit-buying availability and current balance for the frontend.
|
||||
|
||||
``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE
|
||||
divides by 1M when displaying.
|
||||
"""
|
||||
user = auth.user
|
||||
return CreditStripeStatusResponse(
|
||||
credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED,
|
||||
credit_micros_balance=user.credit_micros_balance,
|
||||
|
|
@ -733,12 +737,13 @@ async def get_credit_status(
|
|||
|
||||
@router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse)
|
||||
async def get_credit_purchases(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> CreditPurchaseHistoryResponse:
|
||||
"""Return the authenticated user's credit purchase history."""
|
||||
user = auth.user
|
||||
limit = min(limit, 100)
|
||||
purchases = (
|
||||
(
|
||||
|
|
@ -759,7 +764,7 @@ async def get_credit_purchases(
|
|||
|
||||
@router.get("/purchases", response_model=PagePurchaseHistoryResponse)
|
||||
async def get_page_purchases(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
|
|
@ -768,6 +773,7 @@ async def get_page_purchases(
|
|||
|
||||
Page buying is removed; this endpoint stays for historical records.
|
||||
"""
|
||||
user = auth.user
|
||||
limit = min(limit, 100)
|
||||
purchases = (
|
||||
(
|
||||
|
|
@ -804,7 +810,7 @@ def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse:
|
|||
)
|
||||
async def create_auto_reload_setup_session(
|
||||
body: CreateAutoReloadSetupSessionRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> CreateAutoReloadSetupSessionResponse:
|
||||
"""Start a ``mode=setup`` checkout session to save a card for auto-reload.
|
||||
|
|
@ -813,6 +819,7 @@ async def create_auto_reload_setup_session(
|
|||
Customer so the card can later be charged off-session. On completion the
|
||||
webhook stores the resulting payment method on the user.
|
||||
"""
|
||||
user = auth.user
|
||||
_ensure_auto_reload_enabled()
|
||||
_ensure_credit_buying_enabled()
|
||||
stripe_client = get_stripe_client()
|
||||
|
|
@ -871,16 +878,17 @@ async def create_auto_reload_setup_session(
|
|||
|
||||
@router.get("/auto-reload", response_model=AutoReloadSettingsResponse)
|
||||
async def get_auto_reload_settings(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
) -> AutoReloadSettingsResponse:
|
||||
"""Return the user's auto-reload configuration and saved-card state."""
|
||||
user = auth.user
|
||||
return _auto_reload_settings_response(user)
|
||||
|
||||
|
||||
@router.put("/auto-reload", response_model=AutoReloadSettingsResponse)
|
||||
async def update_auto_reload_settings(
|
||||
body: UpdateAutoReloadSettingsRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> AutoReloadSettingsResponse:
|
||||
"""Update auto-reload preferences.
|
||||
|
|
@ -889,6 +897,7 @@ async def update_auto_reload_settings(
|
|||
at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and
|
||||
clears any prior failure flag.
|
||||
"""
|
||||
user = auth.user
|
||||
_ensure_auto_reload_enabled()
|
||||
|
||||
locked = (
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.services.memory import (
|
||||
MemoryRead,
|
||||
|
|
@ -15,7 +16,7 @@ from app.services.memory import (
|
|||
reset_memory,
|
||||
save_memory,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -29,9 +30,10 @@ class TeamMemoryUpdate(BaseModel):
|
|||
async def get_team_memory(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
memory_md = await read_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
|
|
@ -45,9 +47,10 @@ async def update_team_memory(
|
|||
search_space_id: int,
|
||||
body: TeamMemoryUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
|
|
@ -63,9 +66,10 @@ async def update_team_memory(
|
|||
async def reset_team_memory(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
result = await reset_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
target_id=search_space_id,
|
||||
|
|
|
|||
|
|
@ -14,15 +14,15 @@ from fastapi.responses import RedirectResponse
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.teams_auth_credentials import TeamsAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
|
|
@ -74,7 +74,10 @@ def get_token_encryption() -> TokenEncryption:
|
|||
|
||||
|
||||
@router.get("/auth/teams/connector/add")
|
||||
async def connect_teams(space_id: int, user: User = Depends(current_active_user)):
|
||||
async def connect_teams(
|
||||
space_id: int,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Initiate Microsoft Teams OAuth flow.
|
||||
|
||||
|
|
@ -85,6 +88,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user)
|
|||
Returns:
|
||||
Authorization URL for redirect
|
||||
"""
|
||||
user = auth.user
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -25,7 +26,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import VideoPresentationRead
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -37,8 +38,9 @@ async def read_video_presentations(
|
|||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List video presentations the user has access to.
|
||||
Requires VIDEO_PRESENTATIONS_READ permission for the search space(s).
|
||||
|
|
@ -49,7 +51,7 @@ async def read_video_presentations(
|
|||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
"You don't have permission to read video presentations in this search space",
|
||||
|
|
@ -89,8 +91,9 @@ async def read_video_presentations(
|
|||
async def read_video_presentation(
|
||||
video_presentation_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific video presentation by ID.
|
||||
Requires authentication with VIDEO_PRESENTATIONS_READ permission.
|
||||
|
|
@ -112,7 +115,7 @@ async def read_video_presentation(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
video_pres.search_space_id,
|
||||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
"You don't have permission to read video presentations in this search space",
|
||||
|
|
@ -132,8 +135,9 @@ async def read_video_presentation(
|
|||
async def delete_video_presentation(
|
||||
video_presentation_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a video presentation.
|
||||
Requires VIDEO_PRESENTATIONS_DELETE permission for the search space.
|
||||
|
|
@ -151,7 +155,7 @@ async def delete_video_presentation(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_video_pres.search_space_id,
|
||||
Permission.VIDEO_PRESENTATIONS_DELETE.value,
|
||||
"You don't have permission to delete video presentations in this search space",
|
||||
|
|
@ -175,8 +179,9 @@ async def stream_slide_audio(
|
|||
video_presentation_id: int,
|
||||
slide_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Stream the audio file for a specific slide in a video presentation.
|
||||
The slide_number is 1-based. Audio path is read from the slides JSONB.
|
||||
|
|
@ -194,7 +199,7 @@ async def stream_slide_audio(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
video_pres.search_space_id,
|
||||
Permission.VIDEO_PRESENTATIONS_READ.value,
|
||||
"You don't have permission to access video presentations in this search space",
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import time
|
|||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from scrapling.fetchers import AsyncFetcher
|
||||
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.users import require_session_context
|
||||
from app.utils.proxy import get_proxy_url
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = {
|
|||
@router.get("/youtube/playlist-videos")
|
||||
async def get_playlist_videos(
|
||||
url: str = Query(..., description="YouTube playlist URL"),
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Resolve a YouTube playlist URL into individual video URLs."""
|
||||
match = _PLAYLIST_ID_RE.search(url)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ from .search_source_connector import (
|
|||
SearchSourceConnectorUpdate,
|
||||
)
|
||||
from .search_space import (
|
||||
SearchSpaceApiAccessUpdate,
|
||||
SearchSpaceBase,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
|
|
@ -243,6 +244,7 @@ __all__ = [
|
|||
"SearchSourceConnectorUpdate",
|
||||
# Search space schemas
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceApiAccessUpdate",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
|
|
|
|||
27
surfsense_backend/app/schemas/pat.py
Normal file
27
surfsense_backend/app/schemas/pat.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PATCreate(BaseModel):
|
||||
label: str = Field(min_length=1, max_length=120)
|
||||
expires_in_days: int | None = Field(default=None, gt=0)
|
||||
|
||||
|
||||
class PATCreated(BaseModel):
|
||||
id: int
|
||||
label: str
|
||||
token: str
|
||||
prefix: str
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class PATRead(BaseModel):
|
||||
id: int
|
||||
label: str
|
||||
prefix: str
|
||||
expires_at: datetime | None = None
|
||||
last_used_at: datetime | None = None
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
@ -24,11 +24,16 @@ class SearchSpaceUpdate(BaseModel):
|
|||
ai_file_sort_enabled: bool | None = None
|
||||
|
||||
|
||||
class SearchSpaceApiAccessUpdate(BaseModel):
|
||||
api_access_enabled: bool
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
citations_enabled: bool
|
||||
api_access_enabled: bool = False
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
ai_file_sort_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from sqlalchemy import delete, or_, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatCommentMention,
|
||||
|
|
@ -138,8 +139,9 @@ async def get_comment_thread_participants(
|
|||
async def get_comments_for_message(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentListResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Get all comments for a message with their replies.
|
||||
|
||||
|
|
@ -169,7 +171,7 @@ async def get_comments_for_message(
|
|||
# Check permission to read comments
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
|
|
@ -268,8 +270,9 @@ async def get_comments_for_message(
|
|||
async def get_comments_for_messages_batch(
|
||||
session: AsyncSession,
|
||||
message_ids: list[int],
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentBatchResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Batch-fetch comments for multiple messages in a single DB round-trip.
|
||||
|
||||
|
|
@ -295,7 +298,7 @@ async def get_comments_for_messages_batch(
|
|||
for ss_id in search_space_ids:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
ss_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
|
|
@ -409,8 +412,9 @@ async def create_comment(
|
|||
session: AsyncSession,
|
||||
message_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a top-level comment on an AI response.
|
||||
|
||||
|
|
@ -521,8 +525,9 @@ async def create_reply(
|
|||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentReplyResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a reply to an existing comment.
|
||||
|
||||
|
|
@ -657,8 +662,9 @@ async def update_comment(
|
|||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentReplyResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Update a comment's content (author only).
|
||||
|
||||
|
|
@ -797,8 +803,9 @@ async def update_comment(
|
|||
async def delete_comment(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> dict:
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a comment (author or user with COMMENTS_DELETE permission).
|
||||
|
||||
|
|
@ -844,9 +851,10 @@ async def delete_comment(
|
|||
|
||||
async def get_user_mentions(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int | None = None,
|
||||
) -> MentionListResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Get mentions for the current user, optionally filtered by search space.
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from sqlalchemy import delete, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -163,8 +164,9 @@ def compute_content_hash(messages: list[dict]) -> str:
|
|||
async def create_snapshot(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> dict:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a public snapshot of a chat thread.
|
||||
|
||||
|
|
@ -186,7 +188,7 @@ async def create_snapshot(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_CREATE.value,
|
||||
"You don't have permission to create public share links",
|
||||
|
|
@ -431,8 +433,9 @@ async def get_public_chat(
|
|||
async def list_snapshots_for_thread(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a thread."""
|
||||
from app.config import config
|
||||
|
||||
|
|
@ -447,7 +450,7 @@ async def list_snapshots_for_thread(
|
|||
# Check permission to view public share links
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
"You don't have permission to view public share links",
|
||||
|
|
@ -477,14 +480,15 @@ async def list_snapshots_for_thread(
|
|||
async def list_snapshots_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a search space."""
|
||||
from app.config import config
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
"You don't have permission to view public share links",
|
||||
|
|
@ -534,8 +538,9 @@ async def delete_snapshot(
|
|||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> bool:
|
||||
user = auth.user
|
||||
"""Delete a specific snapshot. Only thread owner can delete."""
|
||||
# Get snapshot with thread
|
||||
result = await session.execute(
|
||||
|
|
@ -553,7 +558,7 @@ async def delete_snapshot(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
snapshot.thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_DELETE.value,
|
||||
"You don't have permission to delete public share links",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
)
|
||||
from app.agents.chat.runtime.llm_config import AgentConfig
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
|
@ -33,6 +34,7 @@ async def build_main_agent_for_thread(
|
|||
filesystem_selection: FilesystemSelection | None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
auth_context: AuthContext | None = None,
|
||||
) -> Any:
|
||||
return await agent_factory(
|
||||
llm=llm,
|
||||
|
|
@ -48,4 +50,5 @@ async def build_main_agent_for_thread(
|
|||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
|||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.observability import otel as ot
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
|
@ -136,6 +137,7 @@ async def stream_new_chat(
|
|||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
auth_context: AuthContext | None = None,
|
||||
flow: Literal["new", "regenerate"] = "new",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a new chat turn using the SurfSense deep agent.
|
||||
|
|
@ -412,6 +414,7 @@ async def stream_new_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -664,6 +667,7 @@ async def stream_new_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
|||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.observability import otel as ot
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
|
|
@ -102,6 +103,7 @@ async def stream_resume_chat(
|
|||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
auth_context: AuthContext | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Resume a paused HITL turn with the user's decisions.
|
||||
|
||||
|
|
@ -346,6 +348,7 @@ async def stream_resume_chat(
|
|||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -481,6 +484,7 @@ async def stream_resume_chat(
|
|||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
from fastapi_users.authentication import (
|
||||
|
|
@ -14,7 +14,9 @@ from fastapi_users.authentication import (
|
|||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Prompt,
|
||||
|
|
@ -23,10 +25,12 @@ from app.db import (
|
|||
SearchSpaceRole,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
get_user_db,
|
||||
)
|
||||
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
|
||||
from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat
|
||||
from app.utils.refresh_tokens import create_refresh_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -298,5 +302,75 @@ auth_backend = AuthenticationBackend(
|
|||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> AuthContext:
|
||||
"""Resolve the authenticated principal.
|
||||
|
||||
Use this for authorization-sensitive routes where session-vs-PAT matters.
|
||||
FastAPI-Users still handles JWT mechanics; PATs are resolved here so RBAC
|
||||
receives the full SurfSense principal instead of a bare User.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read access token")
|
||||
user = None
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
return AuthContext.session(user)
|
||||
|
||||
|
||||
async def allow_any_principal(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AuthContext:
|
||||
"""Allow either session or PAT principals for bootstrap probes only.
|
||||
|
||||
Routes using this dependency intentionally have no search-space gate.
|
||||
Adding a new call site is a security decision and must be covered by
|
||||
the fail-closed PAT allowlist test.
|
||||
"""
|
||||
return auth
|
||||
|
||||
|
||||
async def require_session_context(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AuthContext:
|
||||
"""Require an interactive session and reject PAT-authenticated requests."""
|
||||
if not auth.is_session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This action requires an interactive session",
|
||||
)
|
||||
return auth
|
||||
|
||||
|
||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
|
|
|||
73
surfsense_backend/app/utils/pat.py
Normal file
73
surfsense_backend/app/utils/pat.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import PersonalAccessToken, User, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAT_PREFIX = "ss_pat_"
|
||||
PAT_TOKEN_BYTES = 32
|
||||
LAST_USED_THROTTLE = timedelta(minutes=10)
|
||||
|
||||
|
||||
def generate_pat() -> str:
|
||||
return f"{PAT_PREFIX}{secrets.token_urlsafe(PAT_TOKEN_BYTES)}"
|
||||
|
||||
|
||||
def hash_pat(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def token_prefix(token: str) -> str:
|
||||
return token[:16]
|
||||
|
||||
|
||||
async def resolve_pat(
|
||||
session: AsyncSession,
|
||||
token: str,
|
||||
) -> PersonalAccessToken | None:
|
||||
now = datetime.now(UTC)
|
||||
result = await session.execute(
|
||||
select(PersonalAccessToken)
|
||||
.options(selectinload(PersonalAccessToken.user))
|
||||
.join(User)
|
||||
.where(
|
||||
PersonalAccessToken.token_hash == hash_pat(token),
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now),
|
||||
User.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def _touch_last_used(token_id: int) -> None:
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.id == token_id)
|
||||
.values(last_used_at=datetime.now(UTC))
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to update PAT last_used_at for token %s", token_id)
|
||||
|
||||
|
||||
def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
|
||||
last_used_at = pat.last_used_at
|
||||
now = datetime.now(UTC)
|
||||
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
|
||||
return
|
||||
|
||||
asyncio.create_task(_touch_last_used(pat.id))
|
||||
|
|
@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
SearchSpaceRole,
|
||||
User,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
|
|
@ -80,9 +80,33 @@ async def get_user_permissions(
|
|||
return []
|
||||
|
||||
|
||||
async def _enforce_api_access_gate(
|
||||
session: AsyncSession,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
search_space: SearchSpace | None = None,
|
||||
) -> SearchSpace:
|
||||
if search_space is None:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
if auth.is_gated and not search_space.api_access_enabled:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="API access is not enabled for this search space.",
|
||||
)
|
||||
|
||||
return search_space
|
||||
|
||||
|
||||
async def check_permission(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
required_permission: str,
|
||||
error_message: str = "You don't have permission to perform this action",
|
||||
|
|
@ -104,7 +128,7 @@ async def check_permission(
|
|||
Raises:
|
||||
HTTPException: If user doesn't have access or permission
|
||||
"""
|
||||
membership = await get_user_membership(session, user.id, search_space_id)
|
||||
membership = await get_user_membership(session, auth.user.id, search_space_id)
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(
|
||||
|
|
@ -123,12 +147,14 @@ async def check_permission(
|
|||
if not has_permission(permissions, required_permission):
|
||||
raise HTTPException(status_code=403, detail=error_message)
|
||||
|
||||
await _enforce_api_access_gate(session, auth, search_space_id)
|
||||
|
||||
return membership
|
||||
|
||||
|
||||
async def check_search_space_access(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
) -> SearchSpaceMembership:
|
||||
"""
|
||||
|
|
@ -146,7 +172,7 @@ async def check_search_space_access(
|
|||
Raises:
|
||||
HTTPException: If user doesn't have access
|
||||
"""
|
||||
membership = await get_user_membership(session, user.id, search_space_id)
|
||||
membership = await get_user_membership(session, auth.user.id, search_space_id)
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(
|
||||
|
|
@ -154,6 +180,8 @@ async def check_search_space_access(
|
|||
detail="You don't have access to this search space",
|
||||
)
|
||||
|
||||
await _enforce_api_access_gate(session, auth, search_space_id)
|
||||
|
||||
return membership
|
||||
|
||||
|
||||
|
|
@ -179,7 +207,7 @@ async def is_search_space_owner(
|
|||
|
||||
async def get_search_space_with_access_check(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
required_permission: str | None = None,
|
||||
) -> tuple[SearchSpace, SearchSpaceMembership]:
|
||||
|
|
@ -210,10 +238,12 @@ async def get_search_space_with_access_check(
|
|||
# Check access
|
||||
if required_permission:
|
||||
membership = await check_permission(
|
||||
session, user, search_space_id, required_permission
|
||||
session, auth, search_space_id, required_permission
|
||||
)
|
||||
else:
|
||||
membership = await check_search_space_access(session, user, search_space_id)
|
||||
membership = await check_search_space_access(session, auth, search_space_id)
|
||||
|
||||
await _enforce_api_access_gate(session, auth, search_space_id, search_space)
|
||||
|
||||
return search_space, membership
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ Maximize logic covered by unit tests; keep integration tests for what genuinely
|
|||
|
||||
`conftest.py` is scoped to its directory and below. Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test.
|
||||
|
||||
For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
|
||||
For API integration tests, override `get_async_session` and `get_auth_context` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
|
||||
|
||||
## Import mode
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ import pytest_asyncio
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -395,7 +396,7 @@ class TestAppendMessageRecoveryAfterFinalize:
|
|||
thread_id=thread_id,
|
||||
request=request,
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=AuthContext.session(db_user),
|
||||
)
|
||||
|
||||
# Response must echo the SERVER's rich payload, not the FE's
|
||||
|
|
@ -469,7 +470,7 @@ class TestAppendMessageRecoveryAfterFinalize:
|
|||
thread_id=thread_id,
|
||||
request=_FakeRequest(fe_request_body),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=AuthContext.session(db_user),
|
||||
)
|
||||
assert fe_response.role == NewChatMessageRole.ASSISTANT
|
||||
|
||||
|
|
@ -552,7 +553,7 @@ class TestAppendMessageRecoveryAfterFinalize:
|
|||
}
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=AuthContext.session(db_user),
|
||||
)
|
||||
assert ok_response.role == NewChatMessageRole.USER
|
||||
assert ok_response.turn_id is None
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from fastapi import HTTPException
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
SearchSpace,
|
||||
|
|
@ -33,6 +34,10 @@ from app.schemas.new_chat import (
|
|||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _auth(user: User) -> AuthContext:
|
||||
return AuthContext.session(user)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User:
|
||||
member = User(
|
||||
|
|
@ -85,7 +90,7 @@ async def _create_thread(
|
|||
visibility=ChatVisibility.PRIVATE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,13 +113,13 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
|
|||
member_threads = await new_chat_routes.list_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
member_search = await new_chat_routes.search_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Visibility",
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
|
||||
assert thread.id not in _active_thread_ids(member_threads)
|
||||
|
|
@ -123,7 +128,7 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
|
|||
await new_chat_routes.get_thread_full(
|
||||
thread_id=thread.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
|
@ -142,24 +147,24 @@ async def test_creator_can_share_thread_and_member_can_list_search_read_it(
|
|||
visibility=ChatVisibility.SEARCH_SPACE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
member_threads = await new_chat_routes.list_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
member_search = await new_chat_routes.search_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Visibility",
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
full_thread = await new_chat_routes.get_thread_full(
|
||||
thread_id=thread.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
|
||||
assert updated.visibility == ChatVisibility.SEARCH_SPACE
|
||||
|
|
@ -181,20 +186,20 @@ async def test_rename_and_archive_do_not_reset_shared_visibility(
|
|||
visibility=ChatVisibility.SEARCH_SPACE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
renamed = await new_chat_routes.update_thread(
|
||||
thread_id=thread.id,
|
||||
thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
archived = await new_chat_routes.update_thread(
|
||||
thread_id=thread.id,
|
||||
thread_update=NewChatThreadUpdate(archived=True),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
assert renamed.visibility == ChatVisibility.SEARCH_SPACE
|
||||
|
|
@ -215,7 +220,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
|
|||
visibility=ChatVisibility.SEARCH_SPACE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
|
|
@ -225,7 +230,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
|
|||
visibility=ChatVisibility.PRIVATE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
|
@ -244,7 +249,7 @@ async def test_creator_can_make_shared_thread_private_again(
|
|||
visibility=ChatVisibility.SEARCH_SPACE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
|
||||
private_again = await new_chat_routes.update_thread_visibility(
|
||||
|
|
@ -253,18 +258,18 @@ async def test_creator_can_make_shared_thread_private_again(
|
|||
visibility=ChatVisibility.PRIVATE,
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
)
|
||||
member_threads = await new_chat_routes.list_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
member_search = await new_chat_routes.search_threads(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Visibility",
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
|
||||
assert private_again.visibility == ChatVisibility.PRIVATE
|
||||
|
|
@ -274,6 +279,6 @@ async def test_creator_can_make_shared_thread_private_again(
|
|||
await new_chat_routes.get_thread_full(
|
||||
thread_id=thread.id,
|
||||
session=db_session,
|
||||
user=db_member,
|
||||
auth=_auth(db_member),
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from httpx import ASGITransport
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.app import app, limiter
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
|
|
@ -22,7 +23,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
|
@ -40,12 +41,12 @@ async def client(
|
|||
async def override_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_user() -> User:
|
||||
return db_user
|
||||
async def override_auth() -> AuthContext:
|
||||
return AuthContext.session(db_user)
|
||||
|
||||
previous_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[get_async_session] = override_session
|
||||
app.dependency_overrides[current_active_user] = override_user
|
||||
app.dependency_overrides[get_auth_context] = override_auth
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Notifications integration fixtures.
|
||||
|
||||
The app's DB session and current-user dependencies are overridden to ride the
|
||||
The app's DB session and auth-context dependencies are overridden to ride the
|
||||
test's transactional `db_session`, so API calls and seeded rows share one
|
||||
transaction that rolls back per test. Overriding `current_active_user` also
|
||||
bypasses real JWT auth, so these tests don't depend on AUTH_TYPE.
|
||||
transaction that rolls back per test. Overriding `get_auth_context` also bypasses
|
||||
real JWT auth, so these tests don't depend on AUTH_TYPE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,8 +17,9 @@ from httpx import ASGITransport
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.app import app, limiter
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
|
@ -33,12 +34,12 @@ async def client(
|
|||
async def override_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_user() -> User:
|
||||
return db_user
|
||||
async def override_auth() -> AuthContext:
|
||||
return AuthContext.session(db_user)
|
||||
|
||||
previous_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[get_async_session] = override_session
|
||||
app.dependency_overrides[current_active_user] = override_user
|
||||
app.dependency_overrides[get_auth_context] = override_auth
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from httpx import ASGITransport
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.app import app, limiter
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config as app_config
|
||||
from app.db import SearchSpace, User, get_async_session
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
|
|
@ -39,7 +40,7 @@ from app.podcasts.schemas import (
|
|||
from app.podcasts.service import PodcastService
|
||||
from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech
|
||||
from app.routes.search_spaces_routes import create_default_roles_and_membership
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
|
@ -54,12 +55,12 @@ async def client(
|
|||
async def override_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_user() -> User:
|
||||
return db_user
|
||||
async def override_auth() -> AuthContext:
|
||||
return AuthContext.session(db_user)
|
||||
|
||||
previous_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[get_async_session] = override_session
|
||||
app.dependency_overrides[current_active_user] = override_user
|
||||
app.dependency_overrides[get_auth_context] = override_auth
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
|
|
@ -290,7 +291,7 @@ def act_as():
|
|||
"""
|
||||
|
||||
def _act(user: User) -> None:
|
||||
app.dependency_overrides[current_active_user] = lambda: user
|
||||
app.dependency_overrides[get_auth_context] = lambda: AuthContext.session(user)
|
||||
|
||||
return _act
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import pytest
|
|||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
|
|
@ -109,7 +110,7 @@ class TestConnectorIndexCrossSpaceAuthz:
|
|||
connector_id=connector_a.id,
|
||||
search_space_id=space_b.id, # the attacker's own space
|
||||
session=db_session,
|
||||
user=attacker,
|
||||
auth=AuthContext.session(attacker),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
|
@ -140,7 +141,7 @@ class TestConnectorIndexCrossSpaceAuthz:
|
|||
connector_id=connector.id,
|
||||
search_space_id=space.id, # the connector's own space
|
||||
session=db_session,
|
||||
user=owner,
|
||||
auth=AuthContext.session(owner),
|
||||
)
|
||||
|
||||
check_permission_mock.assert_awaited_once()
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from sqlalchemy import func, select, text
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
|
|
@ -42,6 +43,7 @@ from app.routes.obsidian_plugin_routes import (
|
|||
obsidian_stats,
|
||||
obsidian_sync,
|
||||
)
|
||||
from app.routes.search_spaces_routes import create_default_roles_and_membership
|
||||
from app.schemas.obsidian_plugin import (
|
||||
ConnectRequest,
|
||||
DeleteAck,
|
||||
|
|
@ -65,6 +67,10 @@ pytestmark = pytest.mark.integration
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _auth(user: User) -> AuthContext:
|
||||
return AuthContext.session(user)
|
||||
|
||||
|
||||
def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload:
|
||||
"""Minimal NotePayload that the schema accepts; the indexer is mocked
|
||||
out so the values don't have to round-trip through the real pipeline."""
|
||||
|
|
@ -102,6 +108,8 @@ async def race_user_and_space(async_engine):
|
|||
)
|
||||
space = SearchSpace(name="Race Space", user_id=user_id)
|
||||
setup.add_all([user, space])
|
||||
await setup.flush()
|
||||
await create_default_roles_and_membership(setup, space.id, user_id)
|
||||
await setup.commit()
|
||||
await setup.refresh(space)
|
||||
space_id = space.id
|
||||
|
|
@ -116,6 +124,14 @@ async def race_user_and_space(async_engine):
|
|||
text("DELETE FROM search_source_connectors WHERE user_id = :uid"),
|
||||
{"uid": user_id},
|
||||
)
|
||||
await cleanup.execute(
|
||||
text("DELETE FROM search_space_memberships WHERE search_space_id = :id"),
|
||||
{"id": space_id},
|
||||
)
|
||||
await cleanup.execute(
|
||||
text("DELETE FROM search_space_roles WHERE search_space_id = :id"),
|
||||
{"id": space_id},
|
||||
)
|
||||
await cleanup.execute(
|
||||
text("DELETE FROM searchspaces WHERE id = :id"),
|
||||
{"id": space_id},
|
||||
|
|
@ -154,7 +170,7 @@ class TestConnectRace:
|
|||
search_space_id=space_id,
|
||||
vault_fingerprint=fingerprint,
|
||||
)
|
||||
await obsidian_connect(payload, user=fresh_user, session=s)
|
||||
await obsidian_connect(payload, auth=_auth(fresh_user), session=s)
|
||||
|
||||
results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True)
|
||||
for r in results:
|
||||
|
|
@ -281,7 +297,7 @@ class TestConnectRace:
|
|||
search_space_id=space_id,
|
||||
vault_fingerprint=fingerprint,
|
||||
),
|
||||
user=fresh_user,
|
||||
auth=_auth(fresh_user),
|
||||
session=s,
|
||||
)
|
||||
|
||||
|
|
@ -294,7 +310,7 @@ class TestConnectRace:
|
|||
search_space_id=space_id,
|
||||
vault_fingerprint=fingerprint,
|
||||
),
|
||||
user=fresh_user,
|
||||
auth=_auth(fresh_user),
|
||||
session=s,
|
||||
)
|
||||
|
||||
|
|
@ -337,7 +353,7 @@ class TestWireContractSmoke:
|
|||
search_space_id=db_search_space.id,
|
||||
vault_fingerprint="fp-" + uuid.uuid4().hex,
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
assert connect_resp.connector_id > 0
|
||||
|
|
@ -361,7 +377,7 @@ class TestWireContractSmoke:
|
|||
_make_note_payload(vault_id, "fail.md", "hash-fail"),
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -394,7 +410,7 @@ class TestWireContractSmoke:
|
|||
_make_note_payload(vault_id, "fail.md", "h2"),
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
assert sync_resp.indexed == 1
|
||||
|
|
@ -420,7 +436,7 @@ class TestWireContractSmoke:
|
|||
RenameItem(old_path="missing.md", new_path="x.md"),
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
assert isinstance(rename_resp, RenameAck)
|
||||
|
|
@ -441,7 +457,7 @@ class TestWireContractSmoke:
|
|||
):
|
||||
delete_resp = await obsidian_delete_notes(
|
||||
DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
assert isinstance(delete_resp, DeleteAck)
|
||||
|
|
@ -456,7 +472,7 @@ class TestWireContractSmoke:
|
|||
# upsert_note was mocked) but the response shape is what we care
|
||||
# about.
|
||||
manifest_resp = await obsidian_manifest(
|
||||
vault_id=vault_id, user=db_user, session=db_session
|
||||
vault_id=vault_id, auth=_auth(db_user), session=db_session
|
||||
)
|
||||
assert isinstance(manifest_resp, ManifestResponse)
|
||||
assert manifest_resp.vault_id == vault_id
|
||||
|
|
@ -464,7 +480,7 @@ class TestWireContractSmoke:
|
|||
|
||||
# 6. /stats — same; row count is 0 because upsert_note was mocked.
|
||||
stats_resp = await obsidian_stats(
|
||||
vault_id=vault_id, user=db_user, session=db_session
|
||||
vault_id=vault_id, auth=_auth(db_user), session=db_session
|
||||
)
|
||||
assert isinstance(stats_resp, StatsResponse)
|
||||
assert stats_resp.vault_id == vault_id
|
||||
|
|
@ -482,7 +498,7 @@ class TestWireContractSmoke:
|
|||
search_space_id=db_search_space.id,
|
||||
vault_fingerprint="fp-" + uuid.uuid4().hex,
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -511,7 +527,7 @@ class TestWireContractSmoke:
|
|||
binary_note,
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -533,7 +549,7 @@ class TestWireContractSmoke:
|
|||
search_space_id=db_search_space.id,
|
||||
vault_fingerprint="fp-" + uuid.uuid4().hex,
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -562,7 +578,7 @@ class TestWireContractSmoke:
|
|||
bad_note,
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -587,7 +603,7 @@ class TestWireContractSmoke:
|
|||
search_space_id=db_search_space.id,
|
||||
vault_fingerprint="fp-" + uuid.uuid4().hex,
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
@ -616,7 +632,7 @@ class TestWireContractSmoke:
|
|||
mismatched,
|
||||
],
|
||||
),
|
||||
user=db_user,
|
||||
auth=_auth(db_user),
|
||||
session=db_session,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
"""Runtime smoke tests for fail-closed PAT authorization primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import PersonalAccessToken, SearchSpace, User
|
||||
from app.users import allow_any_principal, require_session_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _pat_auth(user: User) -> AuthContext:
|
||||
pat = PersonalAccessToken(
|
||||
user_id=user.id,
|
||||
user=user,
|
||||
token_hash="0" * 64,
|
||||
token_prefix="ss_pat_test",
|
||||
label="Test PAT",
|
||||
)
|
||||
return AuthContext.pat_auth(user, pat)
|
||||
|
||||
|
||||
async def test_pat_is_rejected_by_session_only_dependency(db_user: User):
|
||||
auth = _pat_auth(db_user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await require_session_context(auth=auth)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
async def test_pat_is_allowed_by_bootstrap_dependency(db_user: User):
|
||||
auth = _pat_auth(db_user)
|
||||
|
||||
assert await allow_any_principal(auth=auth) is auth
|
||||
|
||||
|
||||
async def test_pat_is_rejected_for_api_disabled_space(
|
||||
db_session: AsyncSession,
|
||||
db_user: User,
|
||||
db_search_space: SearchSpace,
|
||||
):
|
||||
db_search_space.api_access_enabled = False
|
||||
await db_session.flush()
|
||||
auth = _pat_auth(db_user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_search_space_access(db_session, auth, db_search_space.id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "API access is not enabled for this search space."
|
||||
|
||||
|
||||
async def test_pat_is_allowed_for_api_enabled_space(
|
||||
db_session: AsyncSession,
|
||||
db_user: User,
|
||||
db_search_space: SearchSpace,
|
||||
):
|
||||
db_search_space.api_access_enabled = True
|
||||
await db_session.flush()
|
||||
auth = _pat_auth(db_user)
|
||||
|
||||
membership = await check_search_space_access(db_session, auth, db_search_space.id)
|
||||
|
||||
assert membership.user_id == db_user.id
|
||||
assert membership.search_space_id == db_search_space.id
|
||||
|
|
@ -15,6 +15,7 @@ import pytest
|
|||
from fastapi import HTTPException
|
||||
|
||||
import app.automations.services.automation as automation_mod
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
|
||||
from app.automations.schemas.definition.envelope import (
|
||||
AutomationDefinition,
|
||||
|
|
@ -45,7 +46,8 @@ class _FakeSession:
|
|||
|
||||
def _service(search_space: Any) -> AutomationService:
|
||||
return AutomationService(
|
||||
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
|
||||
session=_FakeSession(search_space),
|
||||
auth=AuthContext.session(SimpleNamespace(id="u-1")),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,9 @@ async def cleanup_supervisors():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook")
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
|
||||
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||
|
||||
|
|
@ -47,9 +49,11 @@ async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch):
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult([])
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -67,9 +71,11 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc
|
|||
async def test_start_byo_long_poll_spawns_one_supervisor_per_account(
|
||||
mocker, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult(accounts)
|
||||
|
|
@ -115,9 +121,11 @@ async def test_supervisor_retries_after_run_returns(mocker, monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ async def test_inbox_worker_claims_and_processes_in_fastapi_process(
|
|||
async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch):
|
||||
started = asyncio.Event()
|
||||
stopped = asyncio.Event()
|
||||
monkeypatch.setattr(inbox_worker.config, "GATEWAY_ENABLED", True)
|
||||
|
||||
async def run_forever():
|
||||
started.set()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform
|
||||
from app.routes import gateway_webhook_routes as routes
|
||||
|
||||
|
|
@ -333,7 +334,9 @@ async def test_discord_gateway_install_returns_oauth_url(monkeypatch, mocker):
|
|||
|
||||
response = await routes.install_discord_gateway(
|
||||
search_space_id=123,
|
||||
user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"),
|
||||
auth=AuthContext.session(
|
||||
SimpleNamespace(id="00000000-0000-0000-0000-000000000001")
|
||||
),
|
||||
session=mocker.AsyncMock(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.auth.context import AuthContext
|
||||
from app.routes import agent_revert_route
|
||||
from app.services.revert_service import RevertOutcome
|
||||
|
||||
|
|
@ -147,7 +148,7 @@ class TestFlagGuard:
|
|||
thread_id=1,
|
||||
chat_turn_id="42:1700000000000",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert getattr(exc.value, "status_code", None) == 503
|
||||
|
||||
|
|
@ -167,7 +168,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-empty",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert response.status == "ok"
|
||||
assert response.total == 0
|
||||
|
|
@ -209,7 +210,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-3",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
|
||||
assert response.status == "ok"
|
||||
|
|
@ -248,7 +249,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-i",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert response.status == "ok"
|
||||
assert response.already_reverted == 1
|
||||
|
|
@ -275,7 +276,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-rev",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert response.status == "ok"
|
||||
assert response.results[0].status == "skipped"
|
||||
|
|
@ -315,7 +316,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-mix",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert response.status == "partial"
|
||||
assert response.reverted == 1
|
||||
|
|
@ -354,7 +355,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-fail",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
assert response.status == "partial"
|
||||
assert response.failed == 1
|
||||
|
|
@ -386,7 +387,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-perm",
|
||||
session=session,
|
||||
user=_FakeUser(id="not-owner"),
|
||||
auth=AuthContext.session(_FakeUser(id="not-owner")),
|
||||
)
|
||||
assert response.status == "partial"
|
||||
assert response.results[0].status == "permission_denied"
|
||||
|
|
@ -449,7 +450,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-mixed-all",
|
||||
session=session,
|
||||
user=_FakeUser(), # only id=7 has a different user_id
|
||||
auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id
|
||||
)
|
||||
|
||||
assert response.total == len(rows) == 6
|
||||
|
|
@ -518,7 +519,7 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-race",
|
||||
session=session,
|
||||
user=_FakeUser(),
|
||||
auth=AuthContext.session(_FakeUser()),
|
||||
)
|
||||
|
||||
assert response.failed == 0, (
|
||||
|
|
|
|||
101
surfsense_backend/tests/unit/test_pat_fail_closed_static.py
Normal file
101
surfsense_backend/tests/unit/test_pat_fail_closed_static.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""Static guards for the fail-closed PAT authorization model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||
APP_ROOT = BACKEND_ROOT / "app"
|
||||
|
||||
ALLOW_ANY_EXPECTED = {
|
||||
"app.py": "auth: AuthContext = Depends(allow_any_principal)",
|
||||
"routes/obsidian_plugin_routes.py": (
|
||||
"_auth: AuthContext = Depends(allow_any_principal)"
|
||||
),
|
||||
"routes/search_spaces_routes.py": (
|
||||
"auth: AuthContext = Depends(allow_any_principal)"
|
||||
),
|
||||
}
|
||||
|
||||
CONNECTOR_LISTERS = [
|
||||
"routes/slack_add_connector_route.py",
|
||||
"routes/composio_routes.py",
|
||||
"routes/google_drive_add_connector_route.py",
|
||||
"routes/discord_add_connector_route.py",
|
||||
"routes/dropbox_add_connector_route.py",
|
||||
"routes/onedrive_add_connector_route.py",
|
||||
]
|
||||
|
||||
|
||||
def _python_files() -> list[Path]:
|
||||
return [
|
||||
path
|
||||
for path in APP_ROOT.rglob("*.py")
|
||||
if "__pycache__" not in path.parts
|
||||
]
|
||||
|
||||
|
||||
def test_current_active_user_is_removed_from_app_tree() -> None:
|
||||
offenders = [
|
||||
str(path.relative_to(BACKEND_ROOT))
|
||||
for path in _python_files()
|
||||
if "current_active_user" in path.read_text()
|
||||
]
|
||||
|
||||
assert offenders == []
|
||||
|
||||
|
||||
def test_allow_any_principal_is_only_used_by_bootstrap_allowlist() -> None:
|
||||
actual: dict[str, int] = {}
|
||||
for path in _python_files():
|
||||
text = path.read_text()
|
||||
count = text.count("Depends(allow_any_principal)")
|
||||
if count:
|
||||
actual[str(path.relative_to(APP_ROOT))] = count
|
||||
|
||||
assert actual == dict.fromkeys(ALLOW_ANY_EXPECTED, 1)
|
||||
|
||||
for rel_path, expected_snippet in ALLOW_ANY_EXPECTED.items():
|
||||
text = (APP_ROOT / rel_path).read_text()
|
||||
assert expected_snippet in text
|
||||
|
||||
|
||||
def test_connector_listers_route_pat_through_search_space_gate() -> None:
|
||||
for rel_path in CONNECTOR_LISTERS:
|
||||
text = (APP_ROOT / rel_path).read_text()
|
||||
assert "auth: AuthContext = Depends(get_auth_context)" in text, rel_path
|
||||
assert (
|
||||
"await check_search_space_access(session, auth, connector.search_space_id)"
|
||||
in text
|
||||
), rel_path
|
||||
|
||||
|
||||
def test_identity_routes_are_session_only() -> None:
|
||||
session_only_files = [
|
||||
"routes/prompts_routes.py",
|
||||
"routes/memory_routes.py",
|
||||
"routes/model_list_routes.py",
|
||||
"routes/agent_flags_route.py",
|
||||
"routes/youtube_routes.py",
|
||||
"routes/incentive_tasks_routes.py",
|
||||
"notifications/api/api.py",
|
||||
"routes/chat_comments_routes.py",
|
||||
"routes/public_chat_routes.py",
|
||||
]
|
||||
|
||||
for rel_path in session_only_files:
|
||||
text = (APP_ROOT / rel_path).read_text()
|
||||
assert "require_session_context" in text, rel_path
|
||||
assert "Depends(get_auth_context)" not in text, rel_path
|
||||
|
||||
|
||||
def test_model_connection_personal_writes_default_to_session_required() -> None:
|
||||
text = (APP_ROOT / "routes/model_connections_routes.py").read_text()
|
||||
|
||||
assert "allow_spaceless_pat: bool = False" in text
|
||||
assert "auth.is_gated and not allow_spaceless_pat" in text
|
||||
assert "Managing personal model connections requires an interactive session" in text
|
||||
|
|
@ -16,7 +16,7 @@ const ApiKeyForm = () => {
|
|||
|
||||
const validateForm = () => {
|
||||
if (!apiKey) {
|
||||
setError("API key is required");
|
||||
setError("Personal access token is required");
|
||||
return false;
|
||||
}
|
||||
setError("");
|
||||
|
|
@ -39,11 +39,11 @@ const ApiKeyForm = () => {
|
|||
setLoading(false);
|
||||
|
||||
if (response.ok) {
|
||||
// Store the API key as the token
|
||||
// Store the PAT as the bearer token for existing background handlers.
|
||||
await storage.set("token", apiKey);
|
||||
navigation("/");
|
||||
} else {
|
||||
setError("Invalid API key. Please check and try again.");
|
||||
setError("Invalid personal access token. Please check and try again.");
|
||||
}
|
||||
} catch (error) {
|
||||
setLoading(false);
|
||||
|
|
@ -67,15 +67,15 @@ const ApiKeyForm = () => {
|
|||
|
||||
<div className="bg-gray-800/70 backdrop-blur-sm rounded-xl shadow-xl border border-gray-700 p-6">
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-medium text-white">Enter your API Key</h2>
|
||||
<h2 className="text-xl font-medium text-white">Enter your personal access token</h2>
|
||||
<p className="text-gray-400 text-sm">
|
||||
Your API key connects this extension to the SurfSense.
|
||||
Your personal access token connects this extension to SurfSense.
|
||||
</p>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label htmlFor="apiKey" className="text-sm font-medium text-gray-300">
|
||||
API Key
|
||||
Personal access token
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
|
|
@ -83,7 +83,7 @@ const ApiKeyForm = () => {
|
|||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
className="w-full px-3 py-2 bg-gray-900/50 border border-gray-700 rounded-md focus:outline-none focus:ring-2 focus:ring-teal-500 text-white placeholder:text-gray-500"
|
||||
placeholder="Enter your API key"
|
||||
placeholder="Enter your personal access token"
|
||||
/>
|
||||
{error && <p className="text-red-400 text-sm mt-1">{error}</p>}
|
||||
</div>
|
||||
|
|
@ -106,7 +106,7 @@ const ApiKeyForm = () => {
|
|||
|
||||
<div className="text-center mt-4">
|
||||
<p className="text-sm text-gray-400">
|
||||
Need an API key?{" "}
|
||||
Need a personal access token?{" "}
|
||||
<a
|
||||
href="https://www.surfsense.com"
|
||||
target="_blank"
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ Open **Settings → SurfSense** in Obsidian and fill in:
|
|||
| Setting | Value |
|
||||
| --- | --- |
|
||||
| Server URL | `https://surfsense.com` for SurfSense Cloud, or your self-hosted URL |
|
||||
| API token | Copy from the *Connectors → Obsidian* dialog in the SurfSense web app |
|
||||
| API token | Create a personal access token from the *Connectors → Obsidian* dialog or *User settings → API access* in the SurfSense web app |
|
||||
| Search space | Pick the search space this vault should sync into |
|
||||
| Vault name | Defaults to your Obsidian vault name; rename if you have multiple vaults |
|
||||
| Sync mode | *Auto* (recommended) or *Manual* |
|
||||
|
|
@ -62,11 +62,6 @@ The connector row appears automatically inside SurfSense the first time the
|
|||
plugin successfully calls `/obsidian/connect`. You can manage or delete it
|
||||
from *Connectors → Obsidian* in the web app.
|
||||
|
||||
> **Token lifetime.** The web app currently issues 24-hour JWTs. If you see
|
||||
> *"token expired"* in the plugin status bar, paste a fresh token from the
|
||||
> SurfSense web app. Long-lived personal access tokens are coming in a future
|
||||
> release.
|
||||
|
||||
## Mobile
|
||||
|
||||
The plugin works on Obsidian for iOS and Android. Sync runs whenever the
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ import type {
|
|||
*
|
||||
* Auth + wire contract:
|
||||
* - Every request carries `Authorization: Bearer <token>` only. No
|
||||
* custom headers — the backend identifies the caller from the JWT
|
||||
* custom headers — the backend identifies the caller from the PAT
|
||||
* and feature-detects the API via the `capabilities` array on
|
||||
* `/health` and `/connect`.
|
||||
* - 401 surfaces as `AuthError` so the orchestrator can show the
|
||||
* "token expired, paste a fresh one" UX.
|
||||
* "token invalid or expired" UX.
|
||||
* - HealthResponse / ConnectResponse use index signatures so any
|
||||
* additive backend field (e.g. new capabilities) parses without
|
||||
* breaking the decoder. This mirrors `ConfigDict(extra='ignore')`
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue