mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
feat(auth): require sessions for user-scoped routes
This commit is contained in:
parent
2315b2f344
commit
1f9cf326e5
10 changed files with 75 additions and 62 deletions
|
|
@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|||
from sqlalchemy import case, desc, func, literal, literal_column, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.notifications.api.schemas import (
|
||||
BatchUnreadCountResponse,
|
||||
CategoryUnreadCount,
|
||||
|
|
@ -27,7 +28,7 @@ from app.notifications.api.transform import (
|
|||
from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.types import NotificationCategory, NotificationType
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
|
@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||
@router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse)
|
||||
async def get_unread_counts_batch(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> BatchUnreadCountResponse:
|
||||
"""Unread counts for every category in a single query."""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -86,10 +88,11 @@ async def get_unread_counts_batch(
|
|||
@router.get("/source-types", response_model=SourceTypesResponse)
|
||||
async def get_notification_source_types(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> SourceTypesResponse:
|
||||
"""Distinct connector/document source types for the Status tab filter."""
|
||||
user = auth.user
|
||||
base_filter = [Notification.user_id == user.id]
|
||||
|
||||
if search_space_id is not None:
|
||||
|
|
@ -160,7 +163,7 @@ async def get_unread_count(
|
|||
category: NotificationCategory | None = Query(
|
||||
None, description="Filter by category: 'comments' or 'status'"
|
||||
),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> UnreadCountResponse:
|
||||
"""Total and recent (within sync window) unread counts for the user.
|
||||
|
|
@ -168,6 +171,7 @@ async def get_unread_count(
|
|||
Returning both lets a client hold the older count static while
|
||||
live-syncing the recent ones.
|
||||
"""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -230,10 +234,11 @@ async def list_notifications(
|
|||
),
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of items to skip"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> NotificationListResponse:
|
||||
"""Paginated inbox fallback for items outside the Zero sync window."""
|
||||
user = auth.user
|
||||
query = select(Notification).where(Notification.user_id == user.id)
|
||||
count_query = select(func.count(Notification.id)).where(
|
||||
Notification.user_id == user.id
|
||||
|
|
@ -328,10 +333,11 @@ async def list_notifications(
|
|||
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
|
||||
async def mark_notification_as_read(
|
||||
notification_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkReadResponse:
|
||||
"""Mark one of the user's notifications read; Zero syncs the change."""
|
||||
user = auth.user
|
||||
# Scope to the caller's own notifications.
|
||||
result = await session.execute(
|
||||
select(Notification).where(
|
||||
|
|
@ -364,10 +370,11 @@ async def mark_notification_as_read(
|
|||
|
||||
@router.patch("/read-all", response_model=MarkAllReadResponse)
|
||||
async def mark_all_notifications_as_read(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkAllReadResponse:
|
||||
"""Mark all of the user's notifications read; Zero syncs the changes."""
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
update(Notification)
|
||||
.where(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue