diff --git a/surfsense_backend/app/notifications/api/api.py b/surfsense_backend/app/notifications/api/api.py index 9a136ca7b..7794a5867 100644 --- a/surfsense_backend/app/notifications/api/api.py +++ b/surfsense_backend/app/notifications/api/api.py @@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy import case, desc, func, literal, literal_column, select, update from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.notifications.api.schemas import ( BatchUnreadCountResponse, CategoryUnreadCount, @@ -27,7 +28,7 @@ from app.notifications.api.transform import ( from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS from app.notifications.persistence import Notification from app.notifications.types import NotificationCategory, NotificationType -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(prefix="/notifications", tags=["notifications"]) @@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"]) @router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse) async def get_unread_counts_batch( search_space_id: int | None = Query(None, description="Filter by search space ID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> BatchUnreadCountResponse: """Unread counts for every category in a single query.""" + user = auth.user cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) base_filter = [ @@ -86,10 +88,11 @@ async def get_unread_counts_batch( @router.get("/source-types", response_model=SourceTypesResponse) async def get_notification_source_types( search_space_id: int | None = Query(None, description="Filter by search space ID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> SourceTypesResponse: """Distinct connector/document source types for the Status tab filter.""" + user = auth.user base_filter = [Notification.user_id == user.id] if search_space_id is not None: @@ -160,7 +163,7 @@ async def get_unread_count( category: NotificationCategory | None = Query( None, description="Filter by category: 'comments' or 'status'" ), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> UnreadCountResponse: """Total and recent (within sync window) unread counts for the user. @@ -168,6 +171,7 @@ async def get_unread_count( Returning both lets a client hold the older count static while live-syncing the recent ones. """ + user = auth.user cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) base_filter = [ @@ -230,10 +234,11 @@ async def list_notifications( ), limit: int = Query(50, ge=1, le=100, description="Number of items to return"), offset: int = Query(0, ge=0, description="Number of items to skip"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> NotificationListResponse: """Paginated inbox fallback for items outside the Zero sync window.""" + user = auth.user query = select(Notification).where(Notification.user_id == user.id) count_query = select(func.count(Notification.id)).where( Notification.user_id == user.id @@ -328,10 +333,11 @@ async def list_notifications( @router.patch("/{notification_id}/read", response_model=MarkReadResponse) async def mark_notification_as_read( notification_id: int, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> MarkReadResponse: """Mark one of the user's notifications read; Zero syncs the change.""" + user = auth.user # Scope to the caller's own notifications. result = await session.execute( select(Notification).where( @@ -364,10 +370,11 @@ async def mark_notification_as_read( @router.patch("/read-all", response_model=MarkAllReadResponse) async def mark_all_notifications_as_read( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> MarkAllReadResponse: """Mark all of the user's notifications read; Zero syncs the changes.""" + user = auth.user result = await session.execute( update(Notification) .where( diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py index e97608cbe..222909c59 100644 --- a/surfsense_backend/app/routes/agent_flags_route.py +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import ( AgentFeatureFlags, get_flags, ) +from app.auth.context import AuthContext from app.config import config -from app.db import User -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() @@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel): @router.get("/agent/flags", response_model=AgentFeatureFlagsRead) async def get_agent_flags( - _user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ) -> AgentFeatureFlagsRead: return AgentFeatureFlagsRead.from_flags(get_flags()) diff --git a/surfsense_backend/app/routes/chat_comments_routes.py b/surfsense_backend/app/routes/chat_comments_routes.py index 5bbcd253e..2e1eb1d27 100644 --- a/surfsense_backend/app/routes/chat_comments_routes.py +++ b/surfsense_backend/app/routes/chat_comments_routes.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from app.auth.context import AuthContext -from app.db import User, get_async_session +from app.db import get_async_session from app.schemas.chat_comments import ( CommentBatchRequest, CommentBatchResponse, @@ -26,7 +26,7 @@ from app.services.chat_comments_service import ( get_user_mentions, update_comment, ) -from app.users import get_auth_context +from app.users import require_session_context router = APIRouter() @@ -35,22 +35,20 @@ router = APIRouter() async def batch_list_comments( request: CommentBatchRequest, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """Batch-fetch comments for multiple messages in one request.""" - return await get_comments_for_messages_batch(session, request.message_ids, user) + return await get_comments_for_messages_batch(session, request.message_ids, auth) @router.get("/messages/{message_id}/comments", response_model=CommentListResponse) async def list_comments( message_id: int, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """List all comments for a message with their replies.""" - return await get_comments_for_message(session, message_id, user) + return await get_comments_for_message(session, message_id, auth) @router.post("/messages/{message_id}/comments", response_model=CommentResponse) @@ -58,11 +56,10 @@ async def add_comment( message_id: int, request: CommentCreateRequest, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """Create a top-level comment on an AI response.""" - return await create_comment(session, message_id, request.content, user) + return await create_comment(session, message_id, request.content, auth) @router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse) @@ -70,11 +67,10 @@ async def add_reply( comment_id: int, request: CommentCreateRequest, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """Reply to an existing comment.""" - return await create_reply(session, comment_id, request.content, user) + return await create_reply(session, comment_id, request.content, auth) @router.put("/comments/{comment_id}", response_model=CommentReplyResponse) @@ -82,22 +78,20 @@ async def edit_comment( comment_id: int, request: CommentUpdateRequest, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """Update a comment's content (author only).""" - return await update_comment(session, comment_id, request.content, user) + return await update_comment(session, comment_id, request.content, auth) @router.delete("/comments/{comment_id}") async def remove_comment( comment_id: int, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """Delete a comment (author or user with COMMENTS_DELETE permission).""" - return await delete_comment(session, comment_id, user) + return await delete_comment(session, comment_id, auth) # ============================================================================= @@ -109,8 +103,7 @@ async def remove_comment( async def list_mentions( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): - user = auth.user """List mentions for the current user.""" - return await get_user_mentions(session, user, search_space_id) + return await get_user_mentions(session, auth, search_space_id) diff --git a/surfsense_backend/app/routes/incentive_tasks_routes.py b/surfsense_backend/app/routes/incentive_tasks_routes.py index 1dae09a2d..2635df42f 100644 --- a/surfsense_backend/app/routes/incentive_tasks_routes.py +++ b/surfsense_backend/app/routes/incentive_tasks_routes.py @@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( INCENTIVE_TASKS_CONFIG, IncentiveTaskType, - User, UserIncentiveTask, get_async_session, ) @@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import ( IncentiveTasksResponse, TaskAlreadyCompletedResponse, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"]) @router.get("", response_model=IncentiveTasksResponse) async def get_incentive_tasks( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> IncentiveTasksResponse: """ Get all available incentive tasks with the user's completion status. """ + user = auth.user # Get all completed tasks for this user result = await session.execute( select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id) @@ -75,7 +76,7 @@ async def get_incentive_tasks( ) async def complete_task( task_type: IncentiveTaskType, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ) -> CompleteTaskResponse | TaskAlreadyCompletedResponse: """ @@ -84,6 +85,7 @@ async def complete_task( Each task can only be completed once. If the task was already completed, returns the existing completion information without awarding additional credit. """ + user = auth.user # Validate task type exists in config task_config = INCENTIVE_TASKS_CONFIG.get(task_type) if not task_config: diff --git a/surfsense_backend/app/routes/memory_routes.py b/surfsense_backend/app/routes/memory_routes.py index 8e73a277c..d2a82a81c 100644 --- a/surfsense_backend/app/routes/memory_routes.py +++ b/surfsense_backend/app/routes/memory_routes.py @@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.auth.context import AuthContext +from app.db import get_async_session from app.services.memory import ( MemoryRead, MemoryScope, @@ -15,7 +16,7 @@ from app.services.memory import ( reset_memory, save_memory, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() @@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel): @router.get("/users/me/memory", response_model=MemoryRead) async def get_user_memory( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user memory_md = await read_memory( scope=MemoryScope.USER, target_id=user.id, @@ -40,9 +42,10 @@ async def get_user_memory( @router.put("/users/me/memory", response_model=MemoryRead) async def update_user_memory( body: MemoryUpdate, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user result = await save_memory( scope=MemoryScope.USER, target_id=user.id, @@ -56,9 +59,10 @@ async def update_user_memory( @router.post("/users/me/memory/reset", response_model=MemoryRead) async def reset_user_memory( - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), session: AsyncSession = Depends(get_async_session), ): + user = auth.user result = await reset_memory( scope=MemoryScope.USER, target_id=user.id, diff --git a/surfsense_backend/app/routes/model_list_routes.py b/surfsense_backend/app/routes/model_list_routes.py index 79ae7221f..e2535f684 100644 --- a/surfsense_backend/app/routes/model_list_routes.py +++ b/surfsense_backend/app/routes/model_list_routes.py @@ -10,9 +10,9 @@ import logging from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from app.db import User +from app.auth.context import AuthContext from app.services.model_list_service import get_model_list -from app.users import current_active_user +from app.users import require_session_context router = APIRouter() logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class ModelListItem(BaseModel): @router.get("/models", response_model=list[ModelListItem]) async def list_available_models( - user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ): """ Return all available models grouped by provider. diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 1ca598fe3..87ed68b2f 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1334,8 +1334,8 @@ async def append_message( Requires CHATS_UPDATE permission. """ try: - # Capture ``user.id`` as a primitive UUID up front. The - # ``current_active_user`` dependency hands us a ``User`` ORM + # Capture ``user.id`` as a primitive UUID up front. The auth + # dependency hands us a ``User`` ORM # row bound to ``session``; if the outer ``except # IntegrityError`` block below ever fires (an unexpected # constraint like a foreign key violation — the common diff --git a/surfsense_backend/app/routes/prompts_routes.py b/surfsense_backend/app/routes/prompts_routes.py index 8dd47537e..b4cb1466c 100644 --- a/surfsense_backend/app/routes/prompts_routes.py +++ b/surfsense_backend/app/routes/prompts_routes.py @@ -3,14 +3,15 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.db import Prompt, SearchSpaceMembership, User, get_async_session +from app.auth.context import AuthContext +from app.db import Prompt, SearchSpaceMembership, get_async_session from app.schemas.prompts import ( PromptCreate, PromptRead, PromptUpdate, PublicPromptRead, ) -from app.users import current_active_user +from app.users import require_session_context router = APIRouter(tags=["Prompts"]) @@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"]) async def list_prompts( search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user query = select(Prompt).where(Prompt.user_id == user.id) if search_space_id is not None: query = query.where(Prompt.search_space_id == search_space_id) @@ -33,8 +35,9 @@ async def list_prompts( async def create_prompt( body: PromptCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user if body.search_space_id is not None: membership = await session.execute( select(SearchSpaceMembership).where( @@ -67,8 +70,9 @@ async def update_prompt( prompt_id: int, body: PromptUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, @@ -99,8 +103,9 @@ async def update_prompt( async def delete_prompt( prompt_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, @@ -119,8 +124,9 @@ async def delete_prompt( @router.get("/prompts/public", response_model=list[PublicPromptRead]) async def list_public_prompts( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt) .options(selectinload(Prompt.user)) @@ -141,8 +147,9 @@ async def list_public_prompts( async def copy_public_prompt( prompt_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(require_session_context), ): + user = auth.user result = await session.execute( select(Prompt).where( Prompt.id == prompt_id, diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index 4029cd139..70f012911 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from app.auth.context import AuthContext -from app.db import User, get_async_session +from app.db import get_async_session from app.schemas.new_chat import ( CloneResponse, PublicChatResponse, @@ -24,7 +24,7 @@ from app.services.public_chat_service import ( get_snapshot_report, get_snapshot_video_presentation, ) -from app.users import get_auth_context +from app.users import require_session_context router = APIRouter(prefix="/public", tags=["public"]) @@ -47,7 +47,7 @@ async def read_public_chat( async def clone_public_chat( share_token: str, session: AsyncSession = Depends(get_async_session), - auth: AuthContext = Depends(get_auth_context), + auth: AuthContext = Depends(require_session_context), ): user = auth.user """ diff --git a/surfsense_backend/app/routes/youtube_routes.py b/surfsense_backend/app/routes/youtube_routes.py index 9fc6d1dfc..c9d958aa8 100644 --- a/surfsense_backend/app/routes/youtube_routes.py +++ b/surfsense_backend/app/routes/youtube_routes.py @@ -8,8 +8,8 @@ import time from fastapi import APIRouter, Depends, HTTPException, Query from scrapling.fetchers import AsyncFetcher -from app.db import User -from app.users import current_active_user +from app.auth.context import AuthContext +from app.users import require_session_context from app.utils.proxy import get_proxy_url router = APIRouter() @@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = { @router.get("/youtube/playlist-videos") async def get_playlist_videos( url: str = Query(..., description="YouTube playlist URL"), - _user: User = Depends(current_active_user), + _auth: AuthContext = Depends(require_session_context), ): """Resolve a YouTube playlist URL into individual video URLs.""" match = _PLAYLIST_ID_RE.search(url)