feat(auth): require sessions for user-scoped routes

This commit is contained in:
Anish Sarkar 2026-06-20 01:57:48 +05:30
parent 2315b2f344
commit 1f9cf326e5
10 changed files with 75 additions and 62 deletions

View file

@ -8,7 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import case, desc, func, literal, literal_column, select, update from sqlalchemy import case, desc, func, literal, literal_column, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session from app.auth.context import AuthContext
from app.db import get_async_session
from app.notifications.api.schemas import ( from app.notifications.api.schemas import (
BatchUnreadCountResponse, BatchUnreadCountResponse,
CategoryUnreadCount, CategoryUnreadCount,
@ -27,7 +28,7 @@ from app.notifications.api.transform import (
from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS from app.notifications.constants import CATEGORY_TYPES, SYNC_WINDOW_DAYS
from app.notifications.persistence import Notification from app.notifications.persistence import Notification
from app.notifications.types import NotificationCategory, NotificationType from app.notifications.types import NotificationCategory, NotificationType
from app.users import current_active_user from app.users import require_session_context
router = APIRouter(prefix="/notifications", tags=["notifications"]) router = APIRouter(prefix="/notifications", tags=["notifications"])
@ -35,10 +36,11 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse) @router.get("/unread-counts-batch", response_model=BatchUnreadCountResponse)
async def get_unread_counts_batch( async def get_unread_counts_batch(
search_space_id: int | None = Query(None, description="Filter by search space ID"), search_space_id: int | None = Query(None, description="Filter by search space ID"),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> BatchUnreadCountResponse: ) -> BatchUnreadCountResponse:
"""Unread counts for every category in a single query.""" """Unread counts for every category in a single query."""
user = auth.user
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
base_filter = [ base_filter = [
@ -86,10 +88,11 @@ async def get_unread_counts_batch(
@router.get("/source-types", response_model=SourceTypesResponse) @router.get("/source-types", response_model=SourceTypesResponse)
async def get_notification_source_types( async def get_notification_source_types(
search_space_id: int | None = Query(None, description="Filter by search space ID"), search_space_id: int | None = Query(None, description="Filter by search space ID"),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> SourceTypesResponse: ) -> SourceTypesResponse:
"""Distinct connector/document source types for the Status tab filter.""" """Distinct connector/document source types for the Status tab filter."""
user = auth.user
base_filter = [Notification.user_id == user.id] base_filter = [Notification.user_id == user.id]
if search_space_id is not None: if search_space_id is not None:
@ -160,7 +163,7 @@ async def get_unread_count(
category: NotificationCategory | None = Query( category: NotificationCategory | None = Query(
None, description="Filter by category: 'comments' or 'status'" None, description="Filter by category: 'comments' or 'status'"
), ),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> UnreadCountResponse: ) -> UnreadCountResponse:
"""Total and recent (within sync window) unread counts for the user. """Total and recent (within sync window) unread counts for the user.
@ -168,6 +171,7 @@ async def get_unread_count(
Returning both lets a client hold the older count static while Returning both lets a client hold the older count static while
live-syncing the recent ones. live-syncing the recent ones.
""" """
user = auth.user
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS) cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
base_filter = [ base_filter = [
@ -230,10 +234,11 @@ async def list_notifications(
), ),
limit: int = Query(50, ge=1, le=100, description="Number of items to return"), limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
offset: int = Query(0, ge=0, description="Number of items to skip"), offset: int = Query(0, ge=0, description="Number of items to skip"),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> NotificationListResponse: ) -> NotificationListResponse:
"""Paginated inbox fallback for items outside the Zero sync window.""" """Paginated inbox fallback for items outside the Zero sync window."""
user = auth.user
query = select(Notification).where(Notification.user_id == user.id) query = select(Notification).where(Notification.user_id == user.id)
count_query = select(func.count(Notification.id)).where( count_query = select(func.count(Notification.id)).where(
Notification.user_id == user.id Notification.user_id == user.id
@ -328,10 +333,11 @@ async def list_notifications(
@router.patch("/{notification_id}/read", response_model=MarkReadResponse) @router.patch("/{notification_id}/read", response_model=MarkReadResponse)
async def mark_notification_as_read( async def mark_notification_as_read(
notification_id: int, notification_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> MarkReadResponse: ) -> MarkReadResponse:
"""Mark one of the user's notifications read; Zero syncs the change.""" """Mark one of the user's notifications read; Zero syncs the change."""
user = auth.user
# Scope to the caller's own notifications. # Scope to the caller's own notifications.
result = await session.execute( result = await session.execute(
select(Notification).where( select(Notification).where(
@ -364,10 +370,11 @@ async def mark_notification_as_read(
@router.patch("/read-all", response_model=MarkAllReadResponse) @router.patch("/read-all", response_model=MarkAllReadResponse)
async def mark_all_notifications_as_read( async def mark_all_notifications_as_read(
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> MarkAllReadResponse: ) -> MarkAllReadResponse:
"""Mark all of the user's notifications read; Zero syncs the changes.""" """Mark all of the user's notifications read; Zero syncs the changes."""
user = auth.user
result = await session.execute( result = await session.execute(
update(Notification) update(Notification)
.where( .where(

View file

@ -26,9 +26,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import (
AgentFeatureFlags, AgentFeatureFlags,
get_flags, get_flags,
) )
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.db import User from app.users import require_session_context
from app.users import current_active_user
router = APIRouter() router = APIRouter()
@ -75,6 +75,6 @@ class AgentFeatureFlagsRead(BaseModel):
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead) @router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
async def get_agent_flags( async def get_agent_flags(
_user: User = Depends(current_active_user), _auth: AuthContext = Depends(require_session_context),
) -> AgentFeatureFlagsRead: ) -> AgentFeatureFlagsRead:
return AgentFeatureFlagsRead.from_flags(get_flags()) return AgentFeatureFlagsRead.from_flags(get_flags())

View file

@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext from app.auth.context import AuthContext
from app.db import User, get_async_session from app.db import get_async_session
from app.schemas.chat_comments import ( from app.schemas.chat_comments import (
CommentBatchRequest, CommentBatchRequest,
CommentBatchResponse, CommentBatchResponse,
@ -26,7 +26,7 @@ from app.services.chat_comments_service import (
get_user_mentions, get_user_mentions,
update_comment, update_comment,
) )
from app.users import get_auth_context from app.users import require_session_context
router = APIRouter() router = APIRouter()
@ -35,22 +35,20 @@ router = APIRouter()
async def batch_list_comments( async def batch_list_comments(
request: CommentBatchRequest, request: CommentBatchRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""Batch-fetch comments for multiple messages in one request.""" """Batch-fetch comments for multiple messages in one request."""
return await get_comments_for_messages_batch(session, request.message_ids, user) return await get_comments_for_messages_batch(session, request.message_ids, auth)
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse) @router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments( async def list_comments(
message_id: int, message_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""List all comments for a message with their replies.""" """List all comments for a message with their replies."""
return await get_comments_for_message(session, message_id, user) return await get_comments_for_message(session, message_id, auth)
@router.post("/messages/{message_id}/comments", response_model=CommentResponse) @router.post("/messages/{message_id}/comments", response_model=CommentResponse)
@ -58,11 +56,10 @@ async def add_comment(
message_id: int, message_id: int,
request: CommentCreateRequest, request: CommentCreateRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""Create a top-level comment on an AI response.""" """Create a top-level comment on an AI response."""
return await create_comment(session, message_id, request.content, user) return await create_comment(session, message_id, request.content, auth)
@router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse) @router.post("/comments/{comment_id}/replies", response_model=CommentReplyResponse)
@ -70,11 +67,10 @@ async def add_reply(
comment_id: int, comment_id: int,
request: CommentCreateRequest, request: CommentCreateRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""Reply to an existing comment.""" """Reply to an existing comment."""
return await create_reply(session, comment_id, request.content, user) return await create_reply(session, comment_id, request.content, auth)
@router.put("/comments/{comment_id}", response_model=CommentReplyResponse) @router.put("/comments/{comment_id}", response_model=CommentReplyResponse)
@ -82,22 +78,20 @@ async def edit_comment(
comment_id: int, comment_id: int,
request: CommentUpdateRequest, request: CommentUpdateRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""Update a comment's content (author only).""" """Update a comment's content (author only)."""
return await update_comment(session, comment_id, request.content, user) return await update_comment(session, comment_id, request.content, auth)
@router.delete("/comments/{comment_id}") @router.delete("/comments/{comment_id}")
async def remove_comment( async def remove_comment(
comment_id: int, comment_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""Delete a comment (author or user with COMMENTS_DELETE permission).""" """Delete a comment (author or user with COMMENTS_DELETE permission)."""
return await delete_comment(session, comment_id, user) return await delete_comment(session, comment_id, auth)
# ============================================================================= # =============================================================================
@ -109,8 +103,7 @@ async def remove_comment(
async def list_mentions( async def list_mentions(
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
"""List mentions for the current user.""" """List mentions for the current user."""
return await get_user_mentions(session, user, search_space_id) return await get_user_mentions(session, auth, search_space_id)

View file

@ -8,10 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ( from app.db import (
INCENTIVE_TASKS_CONFIG, INCENTIVE_TASKS_CONFIG,
IncentiveTaskType, IncentiveTaskType,
User,
UserIncentiveTask, UserIncentiveTask,
get_async_session, get_async_session,
) )
@ -21,19 +21,20 @@ from app.schemas.incentive_tasks import (
IncentiveTasksResponse, IncentiveTasksResponse,
TaskAlreadyCompletedResponse, TaskAlreadyCompletedResponse,
) )
from app.users import current_active_user from app.users import require_session_context
router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"]) router = APIRouter(prefix="/incentive-tasks", tags=["incentive-tasks"])
@router.get("", response_model=IncentiveTasksResponse) @router.get("", response_model=IncentiveTasksResponse)
async def get_incentive_tasks( async def get_incentive_tasks(
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> IncentiveTasksResponse: ) -> IncentiveTasksResponse:
""" """
Get all available incentive tasks with the user's completion status. Get all available incentive tasks with the user's completion status.
""" """
user = auth.user
# Get all completed tasks for this user # Get all completed tasks for this user
result = await session.execute( result = await session.execute(
select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id) select(UserIncentiveTask).where(UserIncentiveTask.user_id == user.id)
@ -75,7 +76,7 @@ async def get_incentive_tasks(
) )
async def complete_task( async def complete_task(
task_type: IncentiveTaskType, task_type: IncentiveTaskType,
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> CompleteTaskResponse | TaskAlreadyCompletedResponse: ) -> CompleteTaskResponse | TaskAlreadyCompletedResponse:
""" """
@ -84,6 +85,7 @@ async def complete_task(
Each task can only be completed once. If the task was already completed, Each task can only be completed once. If the task was already completed,
returns the existing completion information without awarding additional credit. returns the existing completion information without awarding additional credit.
""" """
user = auth.user
# Validate task type exists in config # Validate task type exists in config
task_config = INCENTIVE_TASKS_CONFIG.get(task_type) task_config = INCENTIVE_TASKS_CONFIG.get(task_type)
if not task_config: if not task_config:

View file

@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session from app.auth.context import AuthContext
from app.db import get_async_session
from app.services.memory import ( from app.services.memory import (
MemoryRead, MemoryRead,
MemoryScope, MemoryScope,
@ -15,7 +16,7 @@ from app.services.memory import (
reset_memory, reset_memory,
save_memory, save_memory,
) )
from app.users import current_active_user from app.users import require_session_context
router = APIRouter() router = APIRouter()
@ -26,9 +27,10 @@ class MemoryUpdate(BaseModel):
@router.get("/users/me/memory", response_model=MemoryRead) @router.get("/users/me/memory", response_model=MemoryRead)
async def get_user_memory( async def get_user_memory(
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
): ):
user = auth.user
memory_md = await read_memory( memory_md = await read_memory(
scope=MemoryScope.USER, scope=MemoryScope.USER,
target_id=user.id, target_id=user.id,
@ -40,9 +42,10 @@ async def get_user_memory(
@router.put("/users/me/memory", response_model=MemoryRead) @router.put("/users/me/memory", response_model=MemoryRead)
async def update_user_memory( async def update_user_memory(
body: MemoryUpdate, body: MemoryUpdate,
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
): ):
user = auth.user
result = await save_memory( result = await save_memory(
scope=MemoryScope.USER, scope=MemoryScope.USER,
target_id=user.id, target_id=user.id,
@ -56,9 +59,10 @@ async def update_user_memory(
@router.post("/users/me/memory/reset", response_model=MemoryRead) @router.post("/users/me/memory/reset", response_model=MemoryRead)
async def reset_user_memory( async def reset_user_memory(
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
): ):
user = auth.user
result = await reset_memory( result = await reset_memory(
scope=MemoryScope.USER, scope=MemoryScope.USER,
target_id=user.id, target_id=user.id,

View file

@ -10,9 +10,9 @@ import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from app.db import User from app.auth.context import AuthContext
from app.services.model_list_service import get_model_list from app.services.model_list_service import get_model_list
from app.users import current_active_user from app.users import require_session_context
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class ModelListItem(BaseModel):
@router.get("/models", response_model=list[ModelListItem]) @router.get("/models", response_model=list[ModelListItem])
async def list_available_models( async def list_available_models(
user: User = Depends(current_active_user), _auth: AuthContext = Depends(require_session_context),
): ):
""" """
Return all available models grouped by provider. Return all available models grouped by provider.

View file

@ -1334,8 +1334,8 @@ async def append_message(
Requires CHATS_UPDATE permission. Requires CHATS_UPDATE permission.
""" """
try: try:
# Capture ``user.id`` as a primitive UUID up front. The # Capture ``user.id`` as a primitive UUID up front. The auth
# ``current_active_user`` dependency hands us a ``User`` ORM # dependency hands us a ``User`` ORM
# row bound to ``session``; if the outer ``except # row bound to ``session``; if the outer ``except
# IntegrityError`` block below ever fires (an unexpected # IntegrityError`` block below ever fires (an unexpected
# constraint like a foreign key violation — the common # constraint like a foreign key violation — the common

View file

@ -3,14 +3,15 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.db import Prompt, SearchSpaceMembership, User, get_async_session from app.auth.context import AuthContext
from app.db import Prompt, SearchSpaceMembership, get_async_session
from app.schemas.prompts import ( from app.schemas.prompts import (
PromptCreate, PromptCreate,
PromptRead, PromptRead,
PromptUpdate, PromptUpdate,
PublicPromptRead, PublicPromptRead,
) )
from app.users import current_active_user from app.users import require_session_context
router = APIRouter(tags=["Prompts"]) router = APIRouter(tags=["Prompts"])
@ -19,8 +20,9 @@ router = APIRouter(tags=["Prompts"])
async def list_prompts( async def list_prompts(
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
query = select(Prompt).where(Prompt.user_id == user.id) query = select(Prompt).where(Prompt.user_id == user.id)
if search_space_id is not None: if search_space_id is not None:
query = query.where(Prompt.search_space_id == search_space_id) query = query.where(Prompt.search_space_id == search_space_id)
@ -33,8 +35,9 @@ async def list_prompts(
async def create_prompt( async def create_prompt(
body: PromptCreate, body: PromptCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
if body.search_space_id is not None: if body.search_space_id is not None:
membership = await session.execute( membership = await session.execute(
select(SearchSpaceMembership).where( select(SearchSpaceMembership).where(
@ -67,8 +70,9 @@ async def update_prompt(
prompt_id: int, prompt_id: int,
body: PromptUpdate, body: PromptUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
result = await session.execute( result = await session.execute(
select(Prompt).where( select(Prompt).where(
Prompt.id == prompt_id, Prompt.id == prompt_id,
@ -99,8 +103,9 @@ async def update_prompt(
async def delete_prompt( async def delete_prompt(
prompt_id: int, prompt_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
result = await session.execute( result = await session.execute(
select(Prompt).where( select(Prompt).where(
Prompt.id == prompt_id, Prompt.id == prompt_id,
@ -119,8 +124,9 @@ async def delete_prompt(
@router.get("/prompts/public", response_model=list[PublicPromptRead]) @router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts( async def list_public_prompts(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
result = await session.execute( result = await session.execute(
select(Prompt) select(Prompt)
.options(selectinload(Prompt.user)) .options(selectinload(Prompt.user))
@ -141,8 +147,9 @@ async def list_public_prompts(
async def copy_public_prompt( async def copy_public_prompt(
prompt_id: int, prompt_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user
result = await session.execute( result = await session.execute(
select(Prompt).where( select(Prompt).where(
Prompt.id == prompt_id, Prompt.id == prompt_id,

View file

@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext from app.auth.context import AuthContext
from app.db import User, get_async_session from app.db import get_async_session
from app.schemas.new_chat import ( from app.schemas.new_chat import (
CloneResponse, CloneResponse,
PublicChatResponse, PublicChatResponse,
@ -24,7 +24,7 @@ from app.services.public_chat_service import (
get_snapshot_report, get_snapshot_report,
get_snapshot_video_presentation, get_snapshot_video_presentation,
) )
from app.users import get_auth_context from app.users import require_session_context
router = APIRouter(prefix="/public", tags=["public"]) router = APIRouter(prefix="/public", tags=["public"])
@ -47,7 +47,7 @@ async def read_public_chat(
async def clone_public_chat( async def clone_public_chat(
share_token: str, share_token: str,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(require_session_context),
): ):
user = auth.user user = auth.user
""" """

View file

@ -8,8 +8,8 @@ import time
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from scrapling.fetchers import AsyncFetcher from scrapling.fetchers import AsyncFetcher
from app.db import User from app.auth.context import AuthContext
from app.users import current_active_user from app.users import require_session_context
from app.utils.proxy import get_proxy_url from app.utils.proxy import get_proxy_url
router = APIRouter() router = APIRouter()
@ -29,7 +29,7 @@ _INNERTUBE_CLIENT = {
@router.get("/youtube/playlist-videos") @router.get("/youtube/playlist-videos")
async def get_playlist_videos( async def get_playlist_videos(
url: str = Query(..., description="YouTube playlist URL"), url: str = Query(..., description="YouTube playlist URL"),
_user: User = Depends(current_active_user), _auth: AuthContext = Depends(require_session_context),
): ):
"""Resolve a YouTube playlist URL into individual video URLs.""" """Resolve a YouTube playlist URL into individual video URLs."""
match = _PLAYLIST_ID_RE.search(url) match = _PLAYLIST_ID_RE.search(url)