mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
feat: resolve auth context from sessions and PATs
This commit is contained in:
parent
4463990ca4
commit
cddfb3660b
4 changed files with 175 additions and 2 deletions
1
surfsense_backend/app/auth/__init__.py
Normal file
1
surfsense_backend/app/auth/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Authentication principals and helpers."""
|
||||||
38
surfsense_backend/app/auth/context.py
Normal file
38
surfsense_backend/app/auth/context.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.db import PersonalAccessToken, User
|
||||||
|
|
||||||
|
AuthMethod = Literal["session", "pat", "system"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AuthContext:
|
||||||
|
"""Typed principal for authorization decisions."""
|
||||||
|
|
||||||
|
user: User
|
||||||
|
method: AuthMethod
|
||||||
|
pat: PersonalAccessToken | None = None
|
||||||
|
source: str | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def session(cls, user: User) -> AuthContext:
|
||||||
|
return cls(user=user, method="session")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pat_auth(cls, user: User, pat: PersonalAccessToken) -> AuthContext:
|
||||||
|
return cls(user=user, method="pat", pat=pat)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def system(cls, user: User, source: str) -> AuthContext:
|
||||||
|
return cls(user=user, method="system", source=source)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_gated(self) -> bool:
|
||||||
|
return self.method == "pat"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_session(self) -> bool:
|
||||||
|
return self.method == "session"
|
||||||
|
|
@ -3,7 +3,7 @@ import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Depends, Request, Response
|
from fastapi import Depends, HTTPException, Request, Response, status
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||||
from fastapi_users.authentication import (
|
from fastapi_users.authentication import (
|
||||||
|
|
@ -16,6 +16,7 @@ from pydantic import BaseModel
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.auth.context import AuthContext
|
||||||
from app.db import (
|
from app.db import (
|
||||||
Prompt,
|
Prompt,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
|
|
@ -23,11 +24,14 @@ from app.db import (
|
||||||
SearchSpaceRole,
|
SearchSpaceRole,
|
||||||
User,
|
User,
|
||||||
async_session_maker,
|
async_session_maker,
|
||||||
|
get_async_session,
|
||||||
get_default_roles_config,
|
get_default_roles_config,
|
||||||
get_user_db,
|
get_user_db,
|
||||||
)
|
)
|
||||||
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
|
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
|
||||||
|
from app.utils.pat import PAT_PREFIX, maybe_touch_last_used, resolve_pat
|
||||||
from app.utils.refresh_tokens import create_refresh_token
|
from app.utils.refresh_tokens import create_refresh_token
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -298,5 +302,62 @@ auth_backend = AuthenticationBackend(
|
||||||
|
|
||||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||||
|
|
||||||
current_active_user = fastapi_users.current_user(active=True)
|
|
||||||
|
async def get_auth_context(
|
||||||
|
request: Request,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user_manager: UserManager = Depends(get_user_manager),
|
||||||
|
) -> AuthContext:
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if not auth_header:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Unauthorized",
|
||||||
|
)
|
||||||
|
|
||||||
|
scheme, _, token = auth_header.partition(" ")
|
||||||
|
if scheme.lower() != "bearer" or not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Unauthorized",
|
||||||
|
)
|
||||||
|
|
||||||
|
if token.startswith(PAT_PREFIX):
|
||||||
|
pat = await resolve_pat(session, token)
|
||||||
|
if pat and pat.user and pat.user.is_active:
|
||||||
|
maybe_touch_last_used(pat)
|
||||||
|
return AuthContext.pat_auth(pat.user, pat)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to read access token")
|
||||||
|
user = None
|
||||||
|
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Unauthorized",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthContext.session(user)
|
||||||
|
|
||||||
|
|
||||||
|
async def current_active_user(
|
||||||
|
auth: AuthContext = Depends(get_auth_context),
|
||||||
|
) -> User:
|
||||||
|
return auth.user
|
||||||
|
|
||||||
|
|
||||||
|
async def require_session_context(
|
||||||
|
auth: AuthContext = Depends(get_auth_context),
|
||||||
|
) -> AuthContext:
|
||||||
|
if not auth.is_session:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="This action requires an interactive session",
|
||||||
|
)
|
||||||
|
return auth
|
||||||
|
|
||||||
|
|
||||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||||
|
|
|
||||||
73
surfsense_backend/app/utils/pat.py
Normal file
73
surfsense_backend/app/utils/pat.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from app.db import PersonalAccessToken, User, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PAT_PREFIX = "ss_pat_"
|
||||||
|
PAT_TOKEN_BYTES = 32
|
||||||
|
LAST_USED_THROTTLE = timedelta(minutes=10)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pat() -> str:
|
||||||
|
return f"{PAT_PREFIX}{secrets.token_urlsafe(PAT_TOKEN_BYTES)}"
|
||||||
|
|
||||||
|
|
||||||
|
def hash_pat(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def token_prefix(token: str) -> str:
|
||||||
|
return token[:16]
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_pat(
|
||||||
|
session: AsyncSession,
|
||||||
|
token: str,
|
||||||
|
) -> PersonalAccessToken | None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
result = await session.execute(
|
||||||
|
select(PersonalAccessToken)
|
||||||
|
.options(selectinload(PersonalAccessToken.user))
|
||||||
|
.join(User)
|
||||||
|
.where(
|
||||||
|
PersonalAccessToken.token_hash == hash_pat(token),
|
||||||
|
(PersonalAccessToken.expires_at.is_(None))
|
||||||
|
| (PersonalAccessToken.expires_at > now),
|
||||||
|
User.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def _touch_last_used(token_id: int) -> None:
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
await session.execute(
|
||||||
|
update(PersonalAccessToken)
|
||||||
|
.where(PersonalAccessToken.id == token_id)
|
||||||
|
.values(last_used_at=datetime.now(UTC))
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to update PAT last_used_at for token %s", token_id)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
|
||||||
|
last_used_at = pat.last_used_at
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
|
||||||
|
return
|
||||||
|
|
||||||
|
asyncio.create_task(_touch_last_used(pat.id))
|
||||||
Loading…
Add table
Add a link
Reference in a new issue