mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
feat(auth): require sessions for user-scoped routes
This commit is contained in:
parent
2315b2f344
commit
1f9cf326e5
10 changed files with 75 additions and 62 deletions
|
|
@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|||
from sqlalchemy import case, desc, func, literal, literal_column, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.notifications.api.schemas import (
|
||||
BatchUnreadCountResponse,
|
||||
CategoryUnreadCount,
|
||||
|
|
@ -27,7 +28,7 @@ from app.notifications.api.transform import (
|
|||
from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.types import NotificationCategory, NotificationType
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
|
@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||
@router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse)
|
||||
async def get_unread_counts_batch(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> BatchUnreadCountResponse:
|
||||
"""Unread counts for every category in a single query."""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -86,10 +88,11 @@ async def get_unread_counts_batch(
|
|||
@router.get("/source-types", response_model=SourceTypesResponse)
|
||||
async def get_notification_source_types(
|
||||
search_space_id: int | None = Query(None, description="Filter by search space ID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> SourceTypesResponse:
|
||||
"""Distinct connector/document source types for the Status tab filter."""
|
||||
user = auth.user
|
||||
base_filter = [Notification.user_id == user.id]
|
||||
|
||||
if search_space_id is not None:
|
||||
|
|
@ -160,7 +163,7 @@ async def get_unread_count(
|
|||
category: NotificationCategory | None = Query(
|
||||
None, description="Filter by category: 'comments' or 'status'"
|
||||
),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> UnreadCountResponse:
|
||||
"""Total and recent (within sync window) unread counts for the user.
|
||||
|
|
@ -168,6 +171,7 @@ async def get_unread_count(
|
|||
Returning both lets a client hold the older count static while
|
||||
live-syncing the recent ones.
|
||||
"""
|
||||
user = auth.user
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
|
||||
|
||||
base_filter = [
|
||||
|
|
@ -230,10 +234,11 @@ async def list_notifications(
|
|||
),
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of items to skip"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> NotificationListResponse:
|
||||
"""Paginated inbox fallback for items outside the Zero sync window."""
|
||||
user = auth.user
|
||||
query = select(Notification).where(Notification.user_id == user.id)
|
||||
count_query = select(func.count(Notification.id)).where(
|
||||
Notification.user_id == user.id
|
||||
|
|
@ -328,10 +333,11 @@ async def list_notifications(
|
|||
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
|
||||
async def mark_notification_as_read(
|
||||
notification_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkReadResponse:
|
||||
"""Mark one of the user's notifications read; Zero syncs the change."""
|
||||
user = auth.user
|
||||
# Scope to the caller's own notifications.
|
||||
result = await session.execute(
|
||||
select(Notification).where(
|
||||
|
|
@ -364,10 +370,11 @@ async def mark_notification_as_read(
|
|||
|
||||
@router.patch("/read-all", response_model=MarkAllReadResponse)
|
||||
async def mark_all_notifications_as_read(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> MarkAllReadResponse:
|
||||
"""Mark all of the user's notifications read; Zero syncs the changes."""
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
update(Notification)
|
||||
.where(
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import (
|
|||
AgentFeatureFlags,
|
||||
get_flags,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel):
|
|||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
async def get_agent_flags(
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
) -> AgentFeatureFlagsRead:
|
||||
return AgentFeatureFlagsRead.from_flags(get_flags())
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.db import get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
CommentBatchResponse,
|
||||
|
|
@ -26,7 +26,7 @@ from app.services.chat_comments_service import (
|
|||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
from app.users import get_auth_context
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -35,22 +35,20 @@ router = APIRouter()
|
|||
async def batch_list_comments(
|
||||
request: CommentBatchRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Batch-fetch comments for multiple messages in one request."""
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, user)
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, auth)
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List all comments for a message with their replies."""
|
||||
return await get_comments_for_message(session, message_id, user)
|
||||
return await get_comments_for_message(session, message_id, auth)
|
||||
|
||||
|
||||
@router.post("/messages/{message_id}/comments", response_model=CommentResponse)
|
||||
|
|
@ -58,11 +56,10 @@ async def add_comment(
|
|||
message_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Create a top-level comment on an AI response."""
|
||||
return await create_comment(session, message_id, request.content, user)
|
||||
return await create_comment(session, message_id, request.content, auth)
|
||||
|
||||
|
||||
@router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse)
|
||||
|
|
@ -70,11 +67,10 @@ async def add_reply(
|
|||
comment_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Reply to an existing comment."""
|
||||
return await create_reply(session, comment_id, request.content, user)
|
||||
return await create_reply(session, comment_id, request.content, auth)
|
||||
|
||||
|
||||
@router.put("/comments/{comment_id}", response_model=CommentReplyResponse)
|
||||
|
|
@ -82,22 +78,20 @@ async def edit_comment(
|
|||
comment_id: int,
|
||||
request: CommentUpdateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Update a comment's content (author only)."""
|
||||
return await update_comment(session, comment_id, request.content, user)
|
||||
return await update_comment(session, comment_id, request.content, auth)
|
||||
|
||||
|
||||
@router.delete("/comments/{comment_id}")
|
||||
async def remove_comment(
|
||||
comment_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Delete a comment (author or user with COMMENTS_DELETE permission)."""
|
||||
return await delete_comment(session, comment_id, user)
|
||||
return await delete_comment(session, comment_id, auth)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -109,8 +103,7 @@ async def remove_comment(
|
|||
async def list_mentions(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List mentions for the current user."""
|
||||
return await get_user_mentions(session, user, search_space_id)
|
||||
return await get_user_mentions(session, auth, search_space_id)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
INCENTIVE_TASKS_CONFIG,
|
||||
IncentiveTaskType,
|
||||
User,
|
||||
UserIncentiveTask,
|
||||
get_async_session,
|
||||
)
|
||||
|
|
@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import (
|
|||
IncentiveTasksResponse,
|
||||
TaskAlreadyCompletedResponse,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=IncentiveTasksResponse)
|
||||
async def get_incentive_tasks(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> IncentiveTasksResponse:
|
||||
"""
|
||||
Get all available incentive tasks with the user's completion status.
|
||||
"""
|
||||
user = auth.user
|
||||
# Get all completed tasks for this user
|
||||
result = await session.execute(
|
||||
select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id)
|
||||
|
|
@ -75,7 +76,7 @@ async def get_incentive_tasks(
|
|||
)
|
||||
async def complete_task(
|
||||
task_type: IncentiveTaskType,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> CompleteTaskResponse | TaskAlreadyCompletedResponse:
|
||||
"""
|
||||
|
|
@ -84,6 +85,7 @@ async def complete_task(
|
|||
Each task can only be completed once. If the task was already completed,
|
||||
returns the existing completion information without awarding additional credit.
|
||||
"""
|
||||
user = auth.user
|
||||
# Validate task type exists in config
|
||||
task_config = INCENTIVE_TASKS_CONFIG.get(task_type)
|
||||
if not task_config:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.services.memory import (
|
||||
MemoryRead,
|
||||
MemoryScope,
|
||||
|
|
@ -15,7 +16,7 @@ from app.services.memory import (
|
|||
reset_memory,
|
||||
save_memory,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel):
|
|||
|
||||
@router.get("/users/me/memory", response_model=MemoryRead)
|
||||
async def get_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
memory_md = await read_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
@ -40,9 +42,10 @@ async def get_user_memory(
|
|||
@router.put("/users/me/memory", response_model=MemoryRead)
|
||||
async def update_user_memory(
|
||||
body: MemoryUpdate,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
@ -56,9 +59,10 @@ async def update_user_memory(
|
|||
|
||||
@router.post("/users/me/memory/reset", response_model=MemoryRead)
|
||||
async def reset_user_memory(
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
user = auth.user
|
||||
result = await reset_memory(
|
||||
scope=MemoryScope.USER,
|
||||
target_id=user.id,
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.db import User
|
||||
from app.auth.context import AuthContext
|
||||
from app.services.model_list_service import get_model_list
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,7 +27,7 @@ class ModelListItem(BaseModel):
|
|||
|
||||
@router.get("/models", response_model=list[ModelListItem])
|
||||
async def list_available_models(
|
||||
user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""
|
||||
Return all available models grouped by provider.
|
||||
|
|
|
|||
|
|
@ -1334,8 +1334,8 @@ async def append_message(
|
|||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Capture ``user.id`` as a primitive UUID up front. The
|
||||
# ``current_active_user`` dependency hands us a ``User`` ORM
|
||||
# Capture ``user.id`` as a primitive UUID up front. The auth
|
||||
# dependency hands us a ``User`` ORM
|
||||
# row bound to ``session``; if the outer ``except
|
||||
# IntegrityError`` block below ever fires (an unexpected
|
||||
# constraint like a foreign key violation — the common
|
||||
|
|
|
|||
|
|
@ -3,14 +3,15 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import Prompt, SearchSpaceMembership, User, get_async_session
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Prompt, SearchSpaceMembership, get_async_session
|
||||
from app.schemas.prompts import (
|
||||
PromptCreate,
|
||||
PromptRead,
|
||||
PromptUpdate,
|
||||
PublicPromptRead,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(tags=["Prompts"])
|
||||
|
||||
|
|
@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"])
|
|||
async def list_prompts(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
query = select(Prompt).where(Prompt.user_id == user.id)
|
||||
if search_space_id is not None:
|
||||
query = query.where(Prompt.search_space_id == search_space_id)
|
||||
|
|
@ -33,8 +35,9 @@ async def list_prompts(
|
|||
async def create_prompt(
|
||||
body: PromptCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
if body.search_space_id is not None:
|
||||
membership = await session.execute(
|
||||
select(SearchSpaceMembership).where(
|
||||
|
|
@ -67,8 +70,9 @@ async def update_prompt(
|
|||
prompt_id: int,
|
||||
body: PromptUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
@ -99,8 +103,9 @@ async def update_prompt(
|
|||
async def delete_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
@ -119,8 +124,9 @@ async def delete_prompt(
|
|||
@router.get("/prompts/public", response_model=list[PublicPromptRead])
|
||||
async def list_public_prompts(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt)
|
||||
.options(selectinload(Prompt.user))
|
||||
|
|
@ -141,8 +147,9 @@ async def list_public_prompts(
|
|||
async def copy_public_prompt(
|
||||
prompt_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(Prompt).where(
|
||||
Prompt.id == prompt_id,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.db import get_async_session
|
||||
from app.schemas.new_chat import (
|
||||
CloneResponse,
|
||||
PublicChatResponse,
|
||||
|
|
@ -24,7 +24,7 @@ from app.services.public_chat_service import (
|
|||
get_snapshot_report,
|
||||
get_snapshot_video_presentation,
|
||||
)
|
||||
from app.users import get_auth_context
|
||||
from app.users import require_session_context
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ async def read_public_chat(
|
|||
async def clone_public_chat(
|
||||
share_token: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import time
|
|||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from scrapling.fetchers import AsyncFetcher
|
||||
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.users import require_session_context
|
||||
from app.utils.proxy import get_proxy_url
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = {
|
|||
@router.get("/youtube/playlist-videos")
|
||||
async def get_playlist_videos(
|
||||
url: str = Query(..., description="YouTube playlist URL"),
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
"""Resolve a YouTube playlist URL into individual video URLs."""
|
||||
match = _PLAYLIST_ID_RE.search(url)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue