mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
feat: resolve auth context from sessions and PATs
This commit is contained in:
parent
4463990ca4
commit
cddfb3660b
4 changed files with 175 additions and 2 deletions
1
surfsense_backend/app/auth/__init__.py
Normal file
1
surfsense_backend/app/auth/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Authentication principals and helpers."""
|
||||
38
surfsense_backend/app/auth/context.py
Normal file
38
surfsense_backend/app/auth/context.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from app.db import PersonalAccessToken, User
|
||||
|
||||
AuthMethod = Literal["session", "pat", "system"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
"""Typed principal for authorization decisions."""
|
||||
|
||||
user: User
|
||||
method: AuthMethod
|
||||
pat: PersonalAccessToken | None = None
|
||||
source: str | None = None
|
||||
|
||||
@classmethod
|
||||
def session(cls, user: User) -> AuthContext:
|
||||
return cls(user=user, method="session")
|
||||
|
||||
@classmethod
|
||||
def pat_auth(cls, user: User, pat: PersonalAccessToken) -> AuthContext:
|
||||
return cls(user=user, method="pat", pat=pat)
|
||||
|
||||
@classmethod
|
||||
def system(cls, user: User, source: str) -> AuthContext:
|
||||
return cls(user=user, method="system", source=source)
|
||||
|
||||
@property
|
||||
def is_gated(self) -> bool:
|
||||
return self.method == "pat"
|
||||
|
||||
@property
|
||||
def is_session(self) -> bool:
|
||||
return self.method == "session"
|
||||
|
|
@ -3,7 +3,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
from fastapi_users.authentication import (
|
||||
|
|
@ -16,6 +16,7 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import update
|
||||
|
||||
from app.config import config
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Prompt,
|
||||
SearchSpace,
|
||||
|
|
@ -23,11 +24,14 @@ from app.db import (
|
|||
SearchSpaceRole,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
get_default_roles_config,
|
||||
get_user_db,
|
||||
)
|
||||
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
|
||||
from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat
|
||||
from app.utils.refresh_tokens import create_refresh_token
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -298,5 +302,62 @@ auth_backend = AuthenticationBackend(
|
|||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> AuthContext:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read access token")
|
||||
user = None
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
return AuthContext.session(user)
|
||||
|
||||
|
||||
async def current_active_user(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> User:
|
||||
return auth.user
|
||||
|
||||
|
||||
async def require_session_context(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AuthContext:
|
||||
if not auth.is_session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This action requires an interactive session",
|
||||
)
|
||||
return auth
|
||||
|
||||
|
||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
|
|
|||
73
surfsense_backend/app/utils/pat.py
Normal file
73
surfsense_backend/app/utils/pat.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import PersonalAccessToken, User, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAT_PREFIX = "ss_pat_"
|
||||
PAT_TOKEN_BYTES = 32
|
||||
LAST_USED_THROTTLE = timedelta(minutes=10)
|
||||
|
||||
|
||||
def generate_pat() -> str:
|
||||
return f"{PAT_PREFIX}{secrets.token_urlsafe(PAT_TOKEN_BYTES)}"
|
||||
|
||||
|
||||
def hash_pat(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def token_prefix(token: str) -> str:
|
||||
return token[:16]
|
||||
|
||||
|
||||
async def resolve_pat(
|
||||
session: AsyncSession,
|
||||
token: str,
|
||||
) -> PersonalAccessToken | None:
|
||||
now = datetime.now(UTC)
|
||||
result = await session.execute(
|
||||
select(PersonalAccessToken)
|
||||
.options(selectinload(PersonalAccessToken.user))
|
||||
.join(User)
|
||||
.where(
|
||||
PersonalAccessToken.token_hash == hash_pat(token),
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now),
|
||||
User.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def _touch_last_used(token_id: int) -> None:
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.id == token_id)
|
||||
.values(last_used_at=datetime.now(UTC))
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to update PAT last_used_at for token %s", token_id)
|
||||
|
||||
|
||||
def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
|
||||
last_used_at = pat.last_used_at
|
||||
now = datetime.now(UTC)
|
||||
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
|
||||
return
|
||||
|
||||
asyncio.create_task(_touch_last_used(pat.id))
|
||||
Loading…
Add table
Add a link
Reference in a new issue