refactor: unify authentication handling by replacing current_active_user with auth context across routes

This commit is contained in:
Anish Sarkar 2026-06-19 21:38:18 +05:30
parent 6fd3f8570e
commit 49b5247210
23 changed files with 192 additions and 84 deletions

View file

@ -27,6 +27,7 @@ from app.agents.chat.runtime.checkpointer import (
close_checkpointer,
setup_checkpointer_tables,
)
from app.auth.context import AuthContext
from app.config import (
config,
initialize_image_gen_router,
@ -34,7 +35,7 @@ from app.config import (
initialize_openrouter_integration,
initialize_pricing_registration,
)
from app.db import User, create_db_and_tables, get_async_session
from app.db import create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.gateway.byo_long_poll import (
start_byo_long_poll_supervisors,
@ -55,7 +56,7 @@ from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.session_events import register_session_hooks
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.users import SECRET, auth_backend, fastapi_users, get_auth_context
from app.utils.perf import log_system_snapshot
_error_logger = logging.getLogger("surfsense.errors")
@ -1032,7 +1033,7 @@ async def readiness_check():
@app.get("/verify-token")
async def authenticated_route(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
):
return {"message": "Token is valid"}
return {"message": "Token is valid", "method": auth.method}

View file

@ -5,7 +5,7 @@ here" affordance. To prevent accidental usage during the gap we return
``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE``
flag flips. Once enabled, the route runs:
1. Authentication via :func:`current_active_user`.
1. Authentication via an interactive session context.
2. Action lookup; 404 if the action does not belong to the thread.
3. Authorization via :func:`app.services.revert_service.can_revert`.
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
@ -33,9 +33,9 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.auth.context import AuthContext
from app.db import (
AgentActionLog,
User,
get_async_session,
)
from app.services.revert_service import (
@ -45,7 +45,7 @@ from app.services.revert_service import (
load_thread,
revert_action,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -57,8 +57,9 @@ async def revert_agent_action(
thread_id: int,
action_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> dict:
user = auth.user
flags = get_flags()
if flags.disable_new_agent_stack or not flags.enable_revert_route:
raise HTTPException(
@ -269,7 +270,7 @@ async def revert_agent_turn(
thread_id: int,
chat_turn_id: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> RevertTurnResponse:
"""Revert every reversible action emitted during ``chat_turn_id``.
@ -281,6 +282,7 @@ async def revert_agent_turn(
Partial success is intentional and returned with HTTP 200. Callers
must inspect ``results[*].status`` to find rows that need attention.
"""
user = auth.user
flags = get_flags()
if flags.disable_new_agent_stack or not flags.enable_revert_route:

View file

@ -10,16 +10,16 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.connectors.airtable_connector import fetch_airtable_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -78,7 +78,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/airtable/connector/add")
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
async def connect_airtable(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Airtable OAuth flow.
@ -89,6 +92,7 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -5,6 +5,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from app.auth.context import AuthContext
from app.db import User, async_session_maker
from app.schemas.auth import (
LogoutAllResponse,
@ -13,7 +14,7 @@ from app.schemas.auth import (
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.users import current_active_user, get_jwt_strategy
from app.users import get_jwt_strategy, require_session_context
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
@ -83,11 +84,14 @@ async def revoke_token(request: LogoutRequest):
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(user: User = Depends(current_active_user)):
async def logout_all_devices(
auth: AuthContext = Depends(require_session_context),
):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
user = auth.user
await revoke_all_user_tokens(user.id)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -16,15 +16,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.clickup_auth_credentials import ClickUpAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
@ -61,7 +61,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/clickup/connector/add")
async def connect_clickup(space_id: int, user: User = Depends(current_active_user)):
async def connect_clickup(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate ClickUp OAuth flow.
@ -72,6 +75,7 @@ async def connect_clickup(space_id: int, user: User = Depends(current_active_use
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
@ -35,7 +36,7 @@ from app.services.composio_service import (
TOOLKIT_TO_CONNECTOR_TYPE,
ComposioService,
)
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
count_connectors_of_type,
get_base_name_for_type,
@ -98,7 +99,7 @@ async def initiate_composio_auth(
toolkit_id: str = Query(
..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')"
),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Composio OAuth flow for a specific toolkit.
@ -110,6 +111,7 @@ async def initiate_composio_auth(
Returns:
JSON with auth_url to redirect user to Composio authorization
"""
user = auth.user
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503,
@ -446,7 +448,7 @@ async def reauth_composio_connector(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -460,6 +462,7 @@ async def reauth_composio_connector(
connector_id: ID of the existing Composio connector to re-authenticate
return_url: Optional frontend path to redirect to after completion
"""
user = auth.user
if not ComposioService.is_enabled():
raise HTTPException(
status_code=503, detail="Composio integration is not enabled."

View file

@ -15,15 +15,15 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -77,7 +77,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/confluence/connector/add")
async def connect_confluence(space_id: int, user: User = Depends(current_active_user)):
async def connect_confluence(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Confluence OAuth flow.
@ -88,6 +91,7 @@ async def connect_confluence(space_id: int, user: User = Depends(current_active_
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -421,10 +425,11 @@ async def reauth_confluence(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Confluence re-authentication to upgrade OAuth scopes."""
user = auth.user
try:
from sqlalchemy.future import select

View file

@ -15,6 +15,7 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
@ -23,7 +24,7 @@ from app.db import (
get_async_session,
)
from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -77,7 +78,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/discord/connector/add")
async def connect_discord(space_id: int, user: User = Depends(current_active_user)):
async def connect_discord(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Discord OAuth flow.
@ -88,6 +92,7 @@ async def connect_discord(space_id: int, user: User = Depends(current_active_use
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -21,6 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.dropbox import DropboxClient, list_folder_contents
from app.db import (
@ -29,7 +30,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -66,8 +67,12 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/dropbox/connector/add")
async def connect_dropbox(space_id: int, user: User = Depends(current_active_user)):
async def connect_dropbox(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""Initiate Dropbox OAuth flow."""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -109,10 +114,11 @@ async def reauth_dropbox(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Re-authenticate an existing Dropbox connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -88,7 +88,11 @@ def get_google_flow():
@router.get("/auth/google/calendar/connector/add")
async def connect_calendar(space_id: int, user: User = Depends(current_active_user)):
async def connect_calendar(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -127,10 +131,11 @@ async def reauth_calendar(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Google Calendar re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -23,6 +23,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_drive import (
GoogleDriveClient,
@ -36,7 +37,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -110,7 +111,10 @@ def get_google_flow():
@router.get("/auth/google/drive/connector/add")
async def connect_drive(space_id: int, user: User = Depends(current_active_user)):
async def connect_drive(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Google Drive OAuth flow.
@ -120,6 +124,7 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -165,7 +170,7 @@ async def reauth_drive(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -178,6 +183,7 @@ async def reauth_drive(
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -15,15 +15,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.google_gmail_connector import fetch_google_user_email
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -92,7 +92,10 @@ def get_google_flow():
@router.get("/auth/google/gmail/connector/add")
async def connect_gmail(space_id: int, user: User = Depends(current_active_user)):
async def connect_gmail(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Google Gmail OAuth flow.
@ -102,6 +105,7 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
Returns:
JSON with auth_url to redirect user to Google authorization
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -145,10 +149,11 @@ async def reauth_gmail(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Gmail re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -16,15 +16,15 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -75,7 +75,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/jira/connector/add")
async def connect_jira(space_id: int, user: User = Depends(current_active_user)):
async def connect_jira(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Jira OAuth flow.
@ -86,6 +89,7 @@ async def connect_jira(space_id: int, user: User = Depends(current_active_user))
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -438,10 +442,11 @@ async def reauth_jira(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Jira re-authentication to upgrade OAuth scopes."""
user = auth.user
try:
from sqlalchemy.future import select

View file

@ -17,16 +17,16 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.linear_connector import fetch_linear_organization_name
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -79,7 +79,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/linear/connector/add")
async def connect_linear(space_id: int, user: User = Depends(current_active_user)):
async def connect_linear(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Linear OAuth flow.
@ -90,6 +93,7 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -134,10 +138,11 @@ async def reauth_linear(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Linear re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -6,13 +6,13 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class AddLumaConnectorRequest(BaseModel):
@router.post("/connectors/luma/add")
async def add_luma_connector(
request: AddLumaConnectorRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -46,6 +46,7 @@ async def add_luma_connector(
Raises:
HTTPException: If connector already exists or validation fails
"""
user = auth.user
try:
# Check if a Luma connector already exists for this search space and user
result = await session.execute(
@ -118,7 +119,7 @@ async def add_luma_connector(
@router.delete("/connectors/luma")
async def delete_luma_connector(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -135,6 +136,7 @@ async def delete_luma_connector(
Raises:
HTTPException: If connector doesn't exist
"""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(
@ -173,7 +175,7 @@ async def delete_luma_connector(
@router.get("/connectors/luma/test")
async def test_luma_connector(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""
@ -190,6 +192,7 @@ async def test_luma_connector(
Raises:
HTTPException: If connector doesn't exist or test fails
"""
user = auth.user
try:
# Get the Luma connector for this search space and user
result = await session.execute(

View file

@ -20,14 +20,14 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import (
OAuthStateManager,
@ -164,8 +164,9 @@ def _frontend_redirect(
async def connect_mcp_service(
service: str,
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)
@ -523,9 +524,10 @@ async def reauth_mcp_service(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
from app.services.mcp_oauth.registry import get_service
svc = get_service(service)

View file

@ -17,15 +17,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -76,7 +76,10 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
@router.get("/auth/notion/connector/add")
async def connect_notion(space_id: int, user: User = Depends(current_active_user)):
async def connect_notion(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Notion OAuth flow.
@ -87,6 +90,7 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -131,10 +135,11 @@ async def reauth_notion(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Initiate Notion re-authentication for an existing connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -24,14 +24,14 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
@ -361,8 +361,9 @@ class OAuthConnectorRoute:
@router.get(f"{oauth.auth_prefix}/connector/add")
async def connect(
space_id: int,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
):
user = auth.user
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -406,9 +407,10 @@ class OAuthConnectorRoute:
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
user = auth.user
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,

View file

@ -21,6 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.auth.context import AuthContext
from app.config import config
from app.connectors.onedrive import OneDriveClient, list_folder_contents
from app.db import (
@ -29,7 +30,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -73,8 +74,12 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/onedrive/connector/add")
async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)):
async def connect_onedrive(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""Initiate OneDrive OAuth flow."""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")
@ -119,10 +124,11 @@ async def reauth_onedrive(
space_id: int,
connector_id: int,
return_url: str | None = None,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session),
):
"""Re-authenticate an existing OneDrive connector."""
user = auth.user
try:
result = await session.execute(
select(SearchSourceConnector).filter(

View file

@ -17,6 +17,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
@ -25,7 +26,7 @@ from app.db import (
get_async_session,
)
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
from app.users import current_active_user
from app.users import current_active_user, require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -78,7 +79,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/slack/connector/add")
async def connect_slack(space_id: int, user: User = Depends(current_active_user)):
async def connect_slack(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Slack OAuth flow.
@ -89,6 +93,7 @@ async def connect_slack(space_id: int, user: User = Depends(current_active_user)
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -18,6 +18,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from stripe import SignatureVerificationError, StripeClient, StripeError
from app.auth.context import AuthContext
from app.config import config
from app.db import (
CreditPurchase,
@ -39,7 +40,7 @@ from app.schemas.stripe import (
StripeWebhookResponse,
UpdateAutoReloadSettingsRequest,
)
from app.users import current_active_user
from app.users import require_session_context
logger = logging.getLogger(__name__)
@ -456,7 +457,7 @@ async def _reconcile_auto_reload_payment_intent(
)
async def create_credit_checkout_session(
body: CreateCreditCheckoutSessionRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> CreateCreditCheckoutSessionResponse:
"""Create a Stripe Checkout Session for buying credit packs.
@ -466,6 +467,7 @@ async def create_credit_checkout_session(
cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page
(ETL), so $1 of credit always buys $1 worth of usage at cost.
"""
user = auth.user
_ensure_credit_buying_enabled()
stripe_client = get_stripe_client()
price_id = _get_required_credit_price_id()
@ -644,7 +646,7 @@ async def stripe_webhook(
@router.get("/finalize-checkout", response_model=FinalizeCheckoutResponse)
async def finalize_checkout(
session_id: str,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> FinalizeCheckoutResponse:
"""Synchronously fulfil a credit checkout session from the success page.
@ -659,6 +661,7 @@ async def finalize_checkout(
Authorization: the session's ``client_reference_id`` must match the
authenticated user's id.
"""
user = auth.user
stripe_client = get_stripe_client()
try:
@ -718,13 +721,14 @@ async def finalize_checkout(
@router.get("/credit-status", response_model=CreditStripeStatusResponse)
async def get_credit_status(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> CreditStripeStatusResponse:
"""Return credit-buying availability and current balance for the frontend.
``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE
divides by 1M when displaying.
"""
user = auth.user
return CreditStripeStatusResponse(
credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED,
credit_micros_balance=user.credit_micros_balance,
@ -733,12 +737,13 @@ async def get_credit_status(
@router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse)
async def get_credit_purchases(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
offset: int = 0,
limit: int = 50,
) -> CreditPurchaseHistoryResponse:
"""Return the authenticated user's credit purchase history."""
user = auth.user
limit = min(limit, 100)
purchases = (
(
@ -759,7 +764,7 @@ async def get_credit_purchases(
@router.get("/purchases", response_model=PagePurchaseHistoryResponse)
async def get_page_purchases(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
offset: int = 0,
limit: int = 50,
@ -768,6 +773,7 @@ async def get_page_purchases(
Page buying is removed; this endpoint stays for historical records.
"""
user = auth.user
limit = min(limit, 100)
purchases = (
(
@ -804,7 +810,7 @@ def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse:
)
async def create_auto_reload_setup_session(
body: CreateAutoReloadSetupSessionRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> CreateAutoReloadSetupSessionResponse:
"""Start a ``mode=setup`` checkout session to save a card for auto-reload.
@ -813,6 +819,7 @@ async def create_auto_reload_setup_session(
Customer so the card can later be charged off-session. On completion the
webhook stores the resulting payment method on the user.
"""
user = auth.user
_ensure_auto_reload_enabled()
_ensure_credit_buying_enabled()
stripe_client = get_stripe_client()
@ -871,16 +878,17 @@ async def create_auto_reload_setup_session(
@router.get("/auto-reload", response_model=AutoReloadSettingsResponse)
async def get_auto_reload_settings(
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
) -> AutoReloadSettingsResponse:
"""Return the user's auto-reload configuration and saved-card state."""
user = auth.user
return _auto_reload_settings_response(user)
@router.put("/auto-reload", response_model=AutoReloadSettingsResponse)
async def update_auto_reload_settings(
body: UpdateAutoReloadSettingsRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(require_session_context),
db_session: AsyncSession = Depends(get_async_session),
) -> AutoReloadSettingsResponse:
"""Update auto-reload preferences.
@ -889,6 +897,7 @@ async def update_auto_reload_settings(
at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and
clears any prior failure flag.
"""
user = auth.user
_ensure_auto_reload_enabled()
locked = (

View file

@ -14,15 +14,15 @@ from fastapi.responses import RedirectResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.schemas.teams_auth_credentials import TeamsAuthCredentialsBase
from app.users import current_active_user
from app.users import require_session_context
from app.utils.connector_naming import (
check_duplicate_connector,
extract_identifier_from_credentials,
@ -74,7 +74,10 @@ def get_token_encryption() -> TokenEncryption:
@router.get("/auth/teams/connector/add")
async def connect_teams(space_id: int, user: User = Depends(current_active_user)):
async def connect_teams(
space_id: int,
auth: AuthContext = Depends(require_session_context),
):
"""
Initiate Microsoft Teams OAuth flow.
@ -85,6 +88,7 @@ async def connect_teams(space_id: int, user: User = Depends(current_active_user)
Returns:
Authorization URL for redirect
"""
user = auth.user
try:
if not space_id:
raise HTTPException(status_code=400, detail="space_id is required")

View file

@ -14,9 +14,10 @@ from fastapi_users.authentication import (
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.auth.context import AuthContext
from app.config import config
from app.db import (
Prompt,
SearchSpace,
@ -31,7 +32,6 @@ from app.db import (
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat
from app.utils.refresh_tokens import create_refresh_token
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
@ -308,6 +308,12 @@ async def get_auth_context(
session: AsyncSession = Depends(get_async_session),
user_manager: UserManager = Depends(get_user_manager),
) -> AuthContext:
"""Resolve the authenticated principal.
Use this for authorization-sensitive routes where session-vs-PAT matters.
FastAPI-Users still handles JWT mechanics; PATs are resolved here so RBAC
receives the full SurfSense principal instead of a bare User.
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
raise HTTPException(
@ -346,12 +352,18 @@ async def get_auth_context(
async def current_active_user(
auth: AuthContext = Depends(get_auth_context),
) -> User:
"""Compatibility wrapper for identity-only routes.
Do not use this for space-scoped authorization or session-grade account
actions. Those should depend on get_auth_context or require_session_context.
"""
return auth.user
async def require_session_context(
auth: AuthContext = Depends(get_auth_context),
) -> AuthContext:
"""Require an interactive session and reject PAT-authenticated requests."""
if not auth.is_session:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,