refactor: route authorization through auth context

This commit is contained in:
Anish Sarkar 2026-06-19 20:27:28 +05:30
parent 630880bf7a
commit 7e8d26fa81
4 changed files with 105 additions and 47 deletions

View file

@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
SearchSpaceRole,
User,
has_permission,
)
@ -80,9 +80,33 @@ async def get_user_permissions(
return []
async def _enforce_api_access_gate(
session: AsyncSession,
auth: AuthContext,
search_space_id: int,
search_space: SearchSpace | None = None,
) -> SearchSpace:
if search_space is None:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
if auth.is_gated and not search_space.api_access_enabled:
raise HTTPException(
status_code=403,
detail="API access is not enabled for this search space.",
)
return search_space
async def check_permission(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
required_permission: str,
error_message: str = "You don't have permission to perform this action",
@ -104,7 +128,7 @@ async def check_permission(
Raises:
HTTPException: If user doesn't have access or permission
"""
membership = await get_user_membership(session, user.id, search_space_id)
membership = await get_user_membership(session, auth.user.id, search_space_id)
if not membership:
raise HTTPException(
@ -123,12 +147,14 @@ async def check_permission(
if not has_permission(permissions, required_permission):
raise HTTPException(status_code=403, detail=error_message)
await _enforce_api_access_gate(session, auth, search_space_id)
return membership
async def check_search_space_access(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
) -> SearchSpaceMembership:
"""
@ -146,7 +172,7 @@ async def check_search_space_access(
Raises:
HTTPException: If user doesn't have access
"""
membership = await get_user_membership(session, user.id, search_space_id)
membership = await get_user_membership(session, auth.user.id, search_space_id)
if not membership:
raise HTTPException(
@ -154,6 +180,8 @@ async def check_search_space_access(
detail="You don't have access to this search space",
)
await _enforce_api_access_gate(session, auth, search_space_id)
return membership
@ -179,7 +207,7 @@ async def is_search_space_owner(
async def get_search_space_with_access_check(
session: AsyncSession,
user: User,
auth: AuthContext,
search_space_id: int,
required_permission: str | None = None,
) -> tuple[SearchSpace, SearchSpaceMembership]:
@ -210,10 +238,12 @@ async def get_search_space_with_access_check(
# Check access
if required_permission:
membership = await check_permission(
session, user, search_space_id, required_permission
session, auth, search_space_id, required_permission
)
else:
membership = await check_search_space_access(session, user, search_space_id)
membership = await check_search_space_access(session, auth, search_space_id)
await _enforce_api_access_gate(session, auth, search_space_id, search_space)
return search_space, membership