Merge pull request #1535 from AnishSarkar22/feat/auth-revamp

feat(auth): complete session auth cutover with desktop oauth support
This commit is contained in:
Rohan Verma 2026-06-25 13:31:02 -07:00 committed by GitHub
commit 6950646bf1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
158 changed files with 4032 additions and 1270 deletions

View file

@ -113,6 +113,7 @@ jobs:
env:
HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }}
GOOGLE_DESKTOP_CLIENT_ID: ${{ vars.GOOGLE_DESKTOP_CLIENT_ID }}
POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }}
POSTHOG_HOST: ${{ vars.POSTHOG_HOST }}
@ -143,6 +144,7 @@ jobs:
working-directory: surfsense_desktop
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GOOGLE_DESKTOP_CLIENT_ID: ${{ vars.GOOGLE_DESKTOP_CLIENT_ID }}
WINDOWS_PUBLISHER_NAME: ${{ vars.WINDOWS_PUBLISHER_NAME }}
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}

1
.gitignore vendored
View file

@ -20,3 +20,4 @@ surfsense_web/blob-report/
content_research/
automation-design-plan.md
automation-frontend-builder-plan.md
surfsense_desktop/.env

View file

@ -30,6 +30,11 @@ SECRET_KEY=replace_me_with_a_random_string
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
AUTH_TYPE=LOCAL
# Cloud only: set COOKIE_DOMAIN=.surfsense.com so api., zero., and app
# subdomains all receive the same first-party session cookie. Leave empty for
# self-hosted Docker where Caddy serves a single origin.
# COOKIE_DOMAIN=
# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them.
DEPLOYMENT_MODE=self-hosted
@ -135,6 +140,19 @@ CERT_EMAIL=
# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
#
# Forward browser session cookies from zero-cache to the query route. Keep this
# enabled before switching the web app to cookie-only auth.
# ZERO_QUERY_FORWARD_COOKIES=true
#
# Optional shared secret for the zero-cache -> /api/zero/query hop. Set the same
# value on zero-cache and the frontend. When unset, the query route accepts the
# request for backward-compatible rollout.
# ZERO_QUERY_API_KEY=
#
# Bounds for auth revocation and RBAC membership changes on already-open sockets.
# ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=60
# ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=60
# ------------------------------------------------------------------------------
# Database (defaults work out of the box, change for security)

View file

@ -99,7 +99,7 @@ services:
# container to run migrations, so you must run `uv run alembic upgrade head`
# from `surfsense_backend/` on the host BEFORE `docker compose up -d`.
zero-cache:
image: rocicorp/zero:1.4.0
image: rocicorp/zero:1.6.0
ports:
- "${ZERO_CACHE_PORT:-4848}:4848"
extra_hosts:
@ -120,6 +120,10 @@ services:
- ZERO_CVR_MAX_CONNS=${ZERO_CVR_MAX_CONNS:-30}
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://host.docker.internal:3000/api/zero/query}
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://host.docker.internal:3000/api/zero/mutate}
- ZERO_QUERY_FORWARD_COOKIES=${ZERO_QUERY_FORWARD_COOKIES:-true}
- ZERO_QUERY_API_KEY=${ZERO_QUERY_API_KEY:-}
- ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
- ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
volumes:
- zero_cache_data:/data
restart: unless-stopped

View file

@ -220,7 +220,7 @@ services:
condition: service_started
zero-cache:
image: rocicorp/zero:1.4.0
image: rocicorp/zero:1.6.0
ports:
- "${ZERO_CACHE_PORT:-4848}:4848"
extra_hosts:
@ -243,6 +243,10 @@ services:
- ZERO_CVR_MAX_CONNS=${ZERO_CVR_MAX_CONNS:-30}
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
- ZERO_QUERY_FORWARD_COOKIES=${ZERO_QUERY_FORWARD_COOKIES:-true}
- ZERO_QUERY_API_KEY=${ZERO_QUERY_API_KEY:-}
- ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
- ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
volumes:
- zero_cache_data:/data
restart: unless-stopped

View file

@ -250,7 +250,7 @@ services:
restart: unless-stopped
zero-cache:
image: rocicorp/zero:1.4.0
image: rocicorp/zero:1.6.0
expose:
- "4848"
extra_hosts:
@ -268,6 +268,10 @@ services:
ZERO_CVR_MAX_CONNS: ${ZERO_CVR_MAX_CONNS:-30}
ZERO_QUERY_URL: ${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
ZERO_MUTATE_URL: ${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
ZERO_QUERY_FORWARD_COOKIES: ${ZERO_QUERY_FORWARD_COOKIES:-true}
ZERO_QUERY_API_KEY: ${ZERO_QUERY_API_KEY:-}
ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS: ${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS: ${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
volumes:
- zero_cache_data:/data
restart: unless-stopped

View file

@ -81,9 +81,24 @@ STRIPE_RECONCILIATION_INTERVAL=10m
SECRET_KEY=SECRET
# JWT Token Lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
# JWT/session lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=1800 # 30 minutes
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 14-day inactivity window
# REFRESH_ROTATION_GRACE_SECONDS=45
# REFRESH_ABSOLUTE_LIFETIME_SECONDS=2592000 # 30-day absolute cap
#
# Web session cookies. Leave COOKIE_DOMAIN empty for self-hosted same-origin
# Docker. In cloud, use .surfsense.com so api., zero., and the app share the
# first-party session cookie.
# SESSION_COOKIE_NAME=surfsense_session
# REFRESH_COOKIE_NAME=surfsense_refresh
# SESSION_COOKIE_SECURE_POLICY=auto
# SESSION_COOKIE_SAMESITE=lax
# COOKIE_DOMAIN=
#
# Comma-separated allow-list for cookie-session unsafe requests. Defaults also
# include NEXT_FRONTEND_URL and SURFSENSE_PUBLIC_URL when set.
# CSRF_ALLOWED_ORIGINS=http://localhost:3000
# Personal Access Tokens (PATs). Empty/unset = no maximum; users may create
# never-expiring PATs. When set, PAT creation requires an expiry <= this many days.
# PAT_MAX_EXPIRY_DAYS=
@ -115,6 +130,8 @@ REGISTRATION_ENABLED=TRUE or FALSE
# For Google Auth Only
GOOGLE_OAUTH_CLIENT_ID=924507538m
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
GOOGLE_DESKTOP_CLIENT_ID=your_google_desktop_client_id
GOOGLE_DESKTOP_CLIENT_SECRET=your_google_desktop_client_secret
GOOGLE_PICKER_API_KEY=your-google-picker-api-key
# Google Connector Specific Configurations

View file

@ -77,7 +77,5 @@ def upgrade() -> None:
def downgrade() -> None:
op.execute(
"ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled"
)
op.execute("ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled")
op.execute("DROP TABLE IF EXISTS personal_access_tokens")

View file

@ -0,0 +1,23 @@
"""publish Zero authz parent tables
Revision ID: 167
Revises: 166
"""
from collections.abc import Sequence
from alembic import op
from app.zero_publication import apply_publication
revision: str = "167"
down_revision: str | None = "166"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
apply_publication(op.get_bind())
def downgrade() -> None:
"""No-op. Historical publication shapes are immutable."""

View file

@ -0,0 +1,66 @@
"""harden refresh token schema
Revision ID: 168
Revises: 167
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "168"
down_revision: str | None = "167"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column(
"refresh_tokens",
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
)
op.add_column(
"refresh_tokens",
sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True),
)
op.execute(
"""
UPDATE refresh_tokens
SET revoked_at = NOW()
WHERE is_revoked = TRUE
"""
)
op.alter_column(
"refresh_tokens",
"token_hash",
existing_type=sa.String(length=256),
type_=sa.String(length=64),
existing_nullable=False,
)
op.drop_column("refresh_tokens", "is_revoked")
def downgrade() -> None:
op.add_column(
"refresh_tokens",
sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"),
)
op.execute(
"""
UPDATE refresh_tokens
SET is_revoked = TRUE
WHERE revoked_at IS NOT NULL
"""
)
op.alter_column("refresh_tokens", "is_revoked", server_default=None)
op.alter_column(
"refresh_tokens",
"token_hash",
existing_type=sa.String(length=64),
type_=sa.String(length=256),
existing_nullable=False,
)
op.drop_column("refresh_tokens", "absolute_expiry")
op.drop_column("refresh_tokens", "revoked_at")

View file

@ -0,0 +1,74 @@
"""migrate Google OAuth account IDs to sub
Revision ID: 169
Revises: 168
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "169"
down_revision: str | None = "168"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def _oauth_account_table_exists() -> bool:
bind = op.get_bind()
return bool(
bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = current_schema()
AND table_name = 'oauth_account'
)
"""
)
).scalar()
)
def upgrade() -> None:
if not _oauth_account_table_exists():
return
op.execute(
"""
UPDATE oauth_account AS legacy
SET account_id = regexp_replace(legacy.account_id, '^people/', '')
WHERE legacy.oauth_name = 'google'
AND legacy.account_id LIKE 'people/%'
AND NOT EXISTS (
SELECT 1
FROM oauth_account AS canonical
WHERE canonical.oauth_name = 'google'
AND canonical.account_id = regexp_replace(legacy.account_id, '^people/', '')
)
"""
)
def downgrade() -> None:
if not _oauth_account_table_exists():
return
op.execute(
"""
UPDATE oauth_account AS canonical
SET account_id = 'people/' || canonical.account_id
WHERE canonical.oauth_name = 'google'
AND canonical.account_id NOT LIKE 'people/%'
AND NOT EXISTS (
SELECT 1
FROM oauth_account AS legacy
WHERE legacy.oauth_name = 'google'
AND legacy.account_id = 'people/' || canonical.account_id
)
"""
)

View file

@ -58,6 +58,7 @@ def create_create_automation_tool(
``AsyncSession`` is opened per call to avoid stale sessions on
compiled-agent cache hits (same pattern as the Notion / memory tools).
"""
@tool
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
"""Draft + save an automation from a natural-language intent.

View file

@ -242,6 +242,7 @@ def create_generate_image_tool(
# Update all image URLs in response_dict to be absolute (for the serving endpoint)
from urllib.parse import urlparse
for image in images:
if image.get("url"):
raw_url: str = image["url"]

View file

@ -10,7 +10,7 @@ from datetime import UTC, datetime
from threading import Lock
import redis
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@ -28,6 +28,7 @@ from app.agents.chat.runtime.checkpointer import (
setup_checkpointer_tables,
)
from app.auth.context import AuthContext
from app.auth.csrf import CsrfOriginMiddleware
from app.config import (
config,
initialize_image_gen_router,
@ -53,8 +54,14 @@ from app.observability import metrics as ot_metrics
from app.observability.bootstrap import init_otel, shutdown_otel
from app.rate_limiter import get_real_client_ip, limiter
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.routes.auth_routes import (
resolve_google_user,
router as auth_router,
session_router,
)
from app.routes.users_routes import router as users_router
from app.routes.zero_context_routes import router as zero_context_router
from app.schemas import UserCreate, UserRead
from app.session_events import register_session_hooks
from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users
from app.utils.perf import log_system_snapshot
@ -803,6 +810,7 @@ allowed_origins.extend(
]
)
app.add_middleware(CsrfOriginMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
@ -855,16 +863,14 @@ if config.AUTH_TYPE != "GOOGLE":
tags=["auth"],
)
# /users/me (read/update profile) is needed in every auth mode, so it stays
# mounted unconditionally.
app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
# /users/me uses the unified auth resolver so web cookie sessions, desktop bearer
# sessions, and PAT principals all resolve through the same authority.
app.include_router(users_router)
# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
app.include_router(session_router)
app.include_router(zero_context_router)
if config.AUTH_TYPE == "GOOGLE":
from fastapi.responses import RedirectResponse
@ -890,36 +896,183 @@ if config.AUTH_TYPE == "GOOGLE":
parsed_url = urlparse(config.BACKEND_URL)
csrf_cookie_domain = parsed_url.hostname
app.include_router(
fastapi_users.get_oauth_router(
google_oauth_client,
auth_backend,
SECRET,
is_verified_by_default=True,
csrf_token_cookie_secure=is_secure_context,
csrf_token_cookie_samesite=csrf_cookie_samesite,
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
from fastapi_users.jwt import decode_jwt
from fastapi_users.router.oauth import (
CSRF_TOKEN_COOKIE_NAME,
CSRF_TOKEN_KEY,
STATE_TOKEN_AUDIENCE,
generate_state_token,
)
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token as google_id_token
from app.users import get_user_manager
def _google_callback_url(request: Request) -> str:
if config.BACKEND_URL:
return f"{config.BACKEND_URL}/auth/google/callback"
return str(request.url_for("google_oauth_callback"))
def _set_google_oauth_csrf_cookie(response: Response, csrf_token: str) -> None:
response.set_cookie(
key=CSRF_TOKEN_COOKIE_NAME,
value=csrf_token,
max_age=3600,
path="/",
domain=csrf_cookie_domain,
secure=is_secure_context,
httponly=False, # Required for cross-site OAuth in Firefox/Safari
samesite=csrf_cookie_samesite,
)
if not config.BACKEND_URL
else fastapi_users.get_oauth_router(
google_oauth_client,
auth_backend,
async def _google_authorization_url(request: Request, response: Response) -> str:
import secrets
csrf_token = secrets.token_urlsafe(32)
state = generate_state_token(
{CSRF_TOKEN_KEY: csrf_token},
SECRET,
is_verified_by_default=True,
redirect_url=f"{config.BACKEND_URL}/auth/google/callback",
csrf_token_cookie_secure=is_secure_context,
csrf_token_cookie_samesite=csrf_cookie_samesite,
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
csrf_token_cookie_domain=csrf_cookie_domain, # Explicitly set cookie domain
),
prefix="/auth/google",
lifetime_seconds=3600,
)
authorization_url = await google_oauth_client.get_authorization_url(
_google_callback_url(request),
state,
scope=["openid", "email", "profile"],
)
_set_google_oauth_csrf_cookie(response, csrf_token)
return authorization_url
@app.get(
"/auth/google/authorize",
tags=["auth"],
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
# it blocks BOTH new OAuth signups AND login of existing OAuth users
# (the fastapi-users OAuth router shares one callback for create+login,
# so this dependency closes both paths together).
dependencies=[Depends(registration_allowed)],
)
async def google_authorize(request: Request, response: Response):
"""Return Google's authorization URL, matching fastapi-users' shape."""
return {"authorization_url": await _google_authorization_url(request, response)}
@app.get(
"/auth/google/callback",
name="google_oauth_callback",
tags=["auth"],
dependencies=[Depends(registration_allowed)],
)
async def google_oauth_callback(
request: Request,
user_manager=Depends(get_user_manager),
):
"""Handle web Google OAuth with the same verified-email policy as desktop."""
import secrets
import httpx
import jwt as pyjwt
state = request.query_params.get("state")
code = request.query_params.get("code")
if not state or not code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth callback missing code or state",
)
try:
state_data = decode_jwt(state, SECRET, [STATE_TOKEN_AUDIENCE])
except pyjwt.DecodeError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="ACCESS_TOKEN_DECODE_ERROR",
) from exc
except pyjwt.ExpiredSignatureError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="ACCESS_TOKEN_ALREADY_EXPIRED",
) from exc
cookie_csrf_token = request.cookies.get(CSRF_TOKEN_COOKIE_NAME)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAUTH_INVALID_STATE",
)
token_payload = {
"client_id": config.GOOGLE_OAUTH_CLIENT_ID,
"client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": _google_callback_url(request),
}
async with httpx.AsyncClient(timeout=10) as client:
token_response = await client.post(
"https://oauth2.googleapis.com/token",
data=token_payload,
)
if token_response.status_code >= 400:
_error_logger.warning(
"Web Google OAuth exchange failed: %s", token_response.text
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="OAuth exchange failed",
)
token_data = token_response.json()
google_access_token = token_data.get("access_token")
google_id_token_value = token_data.get("id_token")
if not google_access_token or not google_id_token_value:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="OAuth exchange failed",
)
try:
claims = google_id_token.verify_oauth2_token(
google_id_token_value,
google_requests.Request(),
config.GOOGLE_OAUTH_CLIENT_ID,
)
except Exception as exc:
_error_logger.warning("Web Google id_token verification failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Google identity token",
) from exc
expires_at = (
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
if token_data.get("expires_in")
else None
)
user = await resolve_google_user(
user_manager=user_manager,
request=request,
google_access_token=google_access_token,
claims=claims,
expires_at=expires_at,
google_refresh_token=token_data.get("refresh_token"),
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="LOGIN_BAD_CREDENTIALS",
)
response = await auth_backend.login(auth_backend.get_strategy(), user)
await user_manager.on_after_login(user, request, response)
response.delete_cookie(
key=CSRF_TOKEN_COOKIE_NAME,
path="/",
domain=csrf_cookie_domain,
secure=is_secure_context,
samesite=csrf_cookie_samesite,
httponly=False,
)
return response
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
# This endpoint performs a server-side redirect instead of returning JSON
@ -944,43 +1097,9 @@ if config.AUTH_TYPE == "GOOGLE":
This fixes CSRF cookie issues in Firefox and Safari where cookies set
via cross-origin fetch requests are not sent on subsequent redirects.
"""
import secrets
from fastapi_users.router.oauth import generate_state_token
# Generate CSRF token
csrf_token = secrets.token_urlsafe(32)
# Build state token
state_data = {"csrftoken": csrf_token}
state = generate_state_token(state_data, SECRET, lifetime_seconds=3600)
# Get the callback URL
if config.BACKEND_URL:
redirect_url = f"{config.BACKEND_URL}/auth/google/callback"
else:
redirect_url = str(request.url_for("oauth:google.jwt.callback"))
# Get authorization URL from Google
authorization_url = await google_oauth_client.get_authorization_url(
redirect_url,
state,
scope=["openid", "email", "profile"],
)
# Create redirect response and set CSRF cookie
response = RedirectResponse(url=authorization_url, status_code=302)
response.set_cookie(
key="fastapiusersoauthcsrf",
value=csrf_token,
max_age=3600,
path="/",
domain=csrf_cookie_domain,
secure=is_secure_context,
httponly=False, # Required for cross-site OAuth in Firefox/Safari
samesite=csrf_cookie_samesite,
)
response = RedirectResponse(url="", status_code=302)
authorization_url = await _google_authorization_url(request, response)
response.headers["location"] = authorization_url
return response

View file

@ -0,0 +1,61 @@
"""CSRF protection for ambient cookie-authenticated requests."""
from __future__ import annotations
from urllib.parse import urlparse
from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.config import config
UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
def _origin_from_url(url: str | None) -> str | None:
if not url:
return None
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return None
return f"{parsed.scheme}://{parsed.netloc}"
def _allowed_origins() -> set[str]:
origins = set(config.CSRF_ALLOWED_ORIGINS)
for url in (config.NEXT_FRONTEND_URL, config.SURFSENSE_PUBLIC_URL):
origin = _origin_from_url(url)
if origin:
origins.add(origin)
return origins
class CsrfOriginMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
if request.method not in UNSAFE_METHODS:
return await call_next(request)
# PAT/Bearer credentials are not ambient browser credentials and are not
# CSRF-able. Enforce only when the web session cookie is the credential.
if (
request.headers.get("Authorization")
or config.SESSION_COOKIE_NAME not in request.cookies
):
return await call_next(request)
origin = request.headers.get("Origin") or _origin_from_url(
request.headers.get("Referer")
)
if origin not in _allowed_origins():
return JSONResponse(
{"detail": "CSRF origin check failed"},
status_code=status.HTTP_403_FORBIDDEN,
)
return await call_next(request)

View file

@ -0,0 +1,130 @@
"""Centralized session-cookie I/O for web authentication."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import Any
import jwt
from fastapi import Request, Response
from app.config import config
class TransportMode(Enum):
COOKIE = "cookie"
HEADER = "header"
def _cookie_secure(request: Request | None = None) -> bool:
policy = config.SESSION_COOKIE_SECURE_POLICY
if policy == "always":
return True
if policy == "never":
return False
if request is not None:
proto = request.headers.get("x-forwarded-proto")
if proto:
return proto.split(",", 1)[0].strip().lower() == "https"
return request.url.scheme == "https"
return bool(config.BACKEND_URL and config.BACKEND_URL.startswith("https://"))
def _set_persistent_cookie(
response: Response,
*,
key: str,
value: str,
max_age: int,
request: Request | None,
) -> None:
expires = datetime.now(UTC) + timedelta(seconds=max_age)
response.set_cookie(
key=key,
value=value,
max_age=max_age,
expires=expires,
httponly=True,
secure=_cookie_secure(request),
samesite=config.SESSION_COOKIE_SAMESITE,
domain=config.COOKIE_DOMAIN,
path="/",
)
def write_session(
response: Response,
access: str,
refresh: str | None = None,
request: Request | None = None,
) -> None:
_set_persistent_cookie(
response,
key=config.SESSION_COOKIE_NAME,
value=access,
max_age=config.ACCESS_TOKEN_LIFETIME_SECONDS,
request=request,
)
if refresh is not None:
_set_persistent_cookie(
response,
key=config.REFRESH_COOKIE_NAME,
value=refresh,
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
request=request,
)
def clear_session(response: Response, request: Request | None = None) -> None:
for key in (config.SESSION_COOKIE_NAME, config.REFRESH_COOKIE_NAME):
response.delete_cookie(
key=key,
path="/",
domain=config.COOKIE_DOMAIN,
secure=_cookie_secure(request),
samesite=config.SESSION_COOKIE_SAMESITE,
httponly=True,
)
def read_refresh(
request: Request, body: Any | None = None
) -> tuple[str | None, TransportMode]:
cookie = request.cookies.get(config.REFRESH_COOKIE_NAME)
if cookie:
return cookie, TransportMode.COOKIE
if body is None:
return None, TransportMode.HEADER
return getattr(body, "refresh_token", None), TransportMode.HEADER
def access_expires_at(access_token: str) -> int:
payload = jwt.decode(
access_token,
config.SECRET_KEY,
algorithms=["HS256"],
options={"verify_aud": False},
)
return int(payload["exp"])
def issue(
response: Response,
mode: TransportMode,
*,
access: str,
refresh: str | None,
access_expires_at: int,
request: Request | None = None,
) -> dict:
if mode is TransportMode.COOKIE:
write_session(response, access, refresh, request)
return {"authenticated": True, "access_expires_at": access_expires_at}
return {
"access_token": access,
"refresh_token": refresh,
"token_type": "bearer",
"access_expires_at": access_expires_at,
}

View file

@ -10,6 +10,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
@ -27,7 +28,6 @@ from app.automations.services.model_policy import (
)
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.auth.context import AuthContext
from app.db import Permission, SearchSpace, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission

View file

@ -6,9 +6,9 @@ from fastapi import Depends, HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission

View file

@ -8,13 +8,13 @@ from fastapi import Depends, HTTPException
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.automations.persistence.enums.trigger_type import TriggerType
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission

View file

@ -188,6 +188,7 @@ celery_app = Celery(
"app.tasks.celery_tasks.document_reindex_tasks",
"app.tasks.celery_tasks.stale_notification_cleanup_task",
"app.tasks.celery_tasks.stripe_reconciliation_task",
"app.tasks.celery_tasks.refresh_token_cleanup_task",
"app.tasks.celery_tasks.auto_reload_task",
"app.tasks.celery_tasks.gateway_tasks",
"app.etl_pipeline.cache.eviction.task",
@ -306,6 +307,11 @@ celery_app.conf.beat_schedule = {
"schedule": crontab(hour="3", minute="17"),
"options": {"expires": 600},
},
"purge-refresh-tokens": {
"task": "purge_refresh_tokens",
"schedule": crontab(hour="3", minute="41"),
"options": {"expires": 600},
},
# Prune the ETL parse cache (TTL + size budget) once daily, off-peak.
"evict-etl-cache": {
"task": "evict_etl_cache",

View file

@ -768,6 +768,8 @@ class Config:
# Google OAuth
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
GOOGLE_DESKTOP_CLIENT_ID = os.getenv("GOOGLE_DESKTOP_CLIENT_ID")
GOOGLE_DESKTOP_CLIENT_SECRET = os.getenv("GOOGLE_DESKTOP_CLIENT_SECRET")
GOOGLE_PICKER_API_KEY = os.getenv("GOOGLE_PICKER_API_KEY")
# Google Calendar redirect URI
@ -914,15 +916,39 @@ class Config:
# JWT Token Lifetimes
ACCESS_TOKEN_LIFETIME_SECONDS = int(
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(30 * 60)) # 30 minutes
)
MIN_ISSUED_AT = int(os.getenv("MIN_ISSUED_AT", "0"))
REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
)
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
PAT_MAX_EXPIRY_DAYS = (
int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
REFRESH_ROTATION_GRACE_SECONDS = int(
os.getenv("REFRESH_ROTATION_GRACE_SECONDS", "45")
)
REFRESH_ABSOLUTE_LIFETIME_SECONDS = int(
os.getenv("REFRESH_ABSOLUTE_LIFETIME_SECONDS", str(30 * 24 * 60 * 60))
)
if REFRESH_ABSOLUTE_LIFETIME_SECONDS <= REFRESH_TOKEN_LIFETIME_SECONDS:
raise ValueError(
"REFRESH_ABSOLUTE_LIFETIME_SECONDS must be greater than "
"REFRESH_TOKEN_LIFETIME_SECONDS so the sliding inactivity window works."
)
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "surfsense_session")
REFRESH_COOKIE_NAME = os.getenv("REFRESH_COOKIE_NAME", "surfsense_refresh")
SESSION_COOKIE_SECURE_POLICY = os.getenv(
"SESSION_COOKIE_SECURE_POLICY", "auto"
).lower()
SESSION_COOKIE_SAMESITE = os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower()
if SESSION_COOKIE_SAMESITE == "none":
raise ValueError("SESSION_COOKIE_SAMESITE=none is not supported")
COOKIE_DOMAIN = os.getenv("COOKIE_DOMAIN") or None
CSRF_ALLOWED_ORIGINS = [
origin.strip()
for origin in os.getenv("CSRF_ALLOWED_ORIGINS", "").split(",")
if origin.strip()
]
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
PAT_MAX_EXPIRY_DAYS = int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
# ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE")

View file

@ -2714,9 +2714,10 @@ class RefreshToken(Base, TimestampMixin):
index=True,
)
user = relationship("User", back_populates="refresh_tokens")
token_hash = Column(String(256), unique=True, nullable=False, index=True)
token_hash = Column(String(64), unique=True, nullable=False, index=True)
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
is_revoked = Column(Boolean, default=False, nullable=False)
revoked_at = Column(TIMESTAMP(timezone=True), nullable=True)
absolute_expiry = Column(TIMESTAMP(timezone=True), nullable=True)
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
@property
@ -2725,7 +2726,7 @@ class RefreshToken(Base, TimestampMixin):
@property
def is_valid(self) -> bool:
return not self.is_expired and not self.is_revoked
return not self.is_expired and self.revoked_at is None
class PersonalAccessToken(BaseModel, TimestampMixin):

View file

@ -28,13 +28,12 @@ from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.auth.context import AuthContext
from app.db import (
AgentActionLog,
NewChatThread,
Permission,
User,
get_async_session,
)
from app.users import get_auth_context
@ -114,7 +113,6 @@ async def list_thread_actions(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
) -> AgentActionListResponse:
user = auth.user
"""List agent actions for a thread, newest first.
Authorization:

View file

@ -30,14 +30,13 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
from app.auth.context import AuthContext
from app.db import (
AgentPermissionRule,
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.users import get_auth_context
@ -136,7 +135,6 @@ def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead:
async def _ensure_search_space_membership_admin(
session: AsyncSession, auth: AuthContext, search_space_id: int
) -> None:
user = auth.user
"""Curating agent rules == "settings" administration on the space."""
space = await session.get(SearchSpace, search_space_id)
if space is None:

View file

@ -1,21 +1,46 @@
"""Authentication routes for refresh token management."""
import logging
from datetime import UTC, datetime
from types import SimpleNamespace
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, status
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi_users import exceptions as fastapi_users_exceptions
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token as google_id_token
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import User, async_session_maker
from app.auth.session_cookies import (
access_expires_at,
clear_session,
issue,
read_refresh,
)
from app.config import config
from app.db import User, async_session_maker, get_async_session
from app.rate_limiter import limiter
from app.schemas.auth import (
DesktopLoginRequest,
DesktopSessionRequest,
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
SessionResponse,
)
from app.users import (
UserManager,
get_auth_context,
get_jwt_strategy,
get_user_manager,
)
from app.users import get_jwt_strategy, require_session_context
from app.utils.refresh_tokens import (
create_refresh_token,
revoke_all_user_tokens,
revoke_refresh_token,
rotate_refresh_token,
@ -25,57 +50,140 @@ from app.utils.refresh_tokens import (
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
session_router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_access_token(request: RefreshTokenRequest):
async def _load_user(user_id) -> User | None:
async with async_session_maker() as session:
result = await session.execute(select(User).where(User.id == user_id))
return result.scalars().first()
async def resolve_google_user(
*,
user_manager: UserManager,
request: Request,
google_access_token: str,
claims: dict,
expires_at: int | None = None,
google_refresh_token: str | None = None,
) -> User:
"""Resolve a Google identity with one policy for web and desktop OAuth.
Email-based account linking is only allowed when Google asserts that the
email is verified. Existing OAuth accounts continue to resolve by provider
account id regardless of the current email claim.
"""
if not claims.get("sub") or not claims.get("email"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Google identity token",
)
sub = claims["sub"]
email_verified = bool(claims.get("email_verified"))
canonical_user = await user_manager.user_db.get_by_oauth_account("google", sub)
if canonical_user is None:
legacy_account_id = f"people/{sub}"
legacy_user = await user_manager.user_db.get_by_oauth_account(
"google", legacy_account_id
)
if legacy_user is not None:
# Fallback for pre-sub Google OAuth rows created by the old web flow.
# TODO: Remove after oauth_account is fully backfilled to bare Google
# sub and production has zero google rows with account_id LIKE 'people/%'.
for oauth_account in legacy_user.oauth_accounts:
if (
oauth_account.oauth_name == "google"
and oauth_account.account_id == legacy_account_id
):
await user_manager.user_db.update_oauth_account(
legacy_user,
oauth_account,
{"account_id": sub},
)
break
try:
return await user_manager.oauth_callback(
"google",
google_access_token,
sub,
claims["email"],
expires_at=expires_at,
refresh_token=google_refresh_token,
request=request,
associate_by_email=email_verified,
is_verified_by_default=email_verified,
)
except fastapi_users_exceptions.UserAlreadyExists as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAUTH_USER_ALREADY_EXISTS",
) from exc
@router.post("/refresh", response_model=None)
@limiter.limit("30/minute")
async def refresh_access_token(
request: Request,
response: Response,
body: RefreshTokenRequest | None = None,
):
"""
Exchange a valid refresh token for a new access token and refresh token.
Implements token rotation for security.
"""
token_record = await validate_refresh_token(request.refresh_token)
if not token_record:
refresh_token, mode = read_refresh(request, body)
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Get user from token record
async with async_session_maker() as session:
result = await session.execute(
select(User).where(User.id == token_record.user_id)
rotation = await rotate_refresh_token(refresh_token)
if not rotation:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
user = result.scalars().first()
user = await _load_user(rotation.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
# Rotate refresh token
new_refresh_token = await rotate_refresh_token(token_record)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
return issue(
response,
mode,
access=access_token,
refresh=rotation.refresh_token,
access_expires_at=access_expires_at(access_token),
request=request,
)
@router.post("/revoke", response_model=LogoutResponse)
async def revoke_token(request: LogoutRequest):
async def revoke_token(
request: Request,
response: Response,
body: LogoutRequest | None = None,
):
"""
Logout current device by revoking the provided refresh token.
Does not require authentication - just the refresh token.
"""
revoked = await revoke_refresh_token(request.refresh_token)
refresh_token, _mode = read_refresh(request, body)
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
clear_session(response, request)
if revoked:
logger.info("User logged out from current device - token revoked")
else:
@ -85,13 +193,185 @@ async def revoke_token(request: LogoutRequest):
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(
auth: AuthContext = Depends(require_session_context),
request: Request,
response: Response,
body: LogoutRequest | None = None,
session: AsyncSession = Depends(get_async_session),
user_manager: UserManager = Depends(get_user_manager),
):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
user = auth.user
user: User | None = None
try:
auth = await get_auth_context(
request, session=session, user_manager=user_manager
)
if auth.is_session:
user = auth.user
except HTTPException:
user = None
if user is None:
refresh_token, _mode = read_refresh(request, body)
token_record = (
await validate_refresh_token(refresh_token) if refresh_token else None
)
if token_record:
user = await _load_user(token_record.user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
await revoke_all_user_tokens(user.id)
clear_session(response, request)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()
@session_router.get("/session", response_model=SessionResponse)
async def get_session(
request: Request,
auth: AuthContext = Depends(get_auth_context),
):
if auth.method == "pat":
return SessionResponse(access_expires_at=None)
access_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if access_token is None:
auth_header = request.headers.get("Authorization")
if auth_header:
scheme, _, token = auth_header.partition(" ")
if scheme.lower() == "bearer" and token:
access_token = token
if access_token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
return SessionResponse(access_expires_at=access_expires_at(access_token))
@session_router.post("/desktop/login", response_model=RefreshTokenResponse)
@limiter.limit("5/minute")
async def desktop_password_login(
request: Request,
body: DesktopLoginRequest,
user_manager: UserManager = Depends(get_user_manager),
):
if config.AUTH_TYPE == "GOOGLE":
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
if not config.REGISTRATION_ENABLED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Registration is disabled",
)
credentials = SimpleNamespace(username=body.email, password=body.password)
user = await user_manager.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="LOGIN_BAD_CREDENTIALS",
)
app_access_token = await get_jwt_strategy().write_token(user)
app_refresh_token = await create_refresh_token(user.id)
await user_manager.on_after_login(user, request, None)
return RefreshTokenResponse(
access_token=app_access_token,
refresh_token=app_refresh_token,
access_expires_at=access_expires_at(app_access_token),
)
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
@limiter.limit("20/minute")
async def create_desktop_session(
request: Request,
body: DesktopSessionRequest,
user_manager: UserManager = Depends(get_user_manager),
):
parsed_redirect = urlparse(body.redirect_uri)
try:
redirect_port = parsed_redirect.port
except ValueError:
redirect_port = None
if not (
parsed_redirect.scheme == "http"
and parsed_redirect.hostname in {"127.0.0.1", "::1"}
and redirect_port is not None
and parsed_redirect.path == "/callback"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI"
)
if not config.GOOGLE_DESKTOP_CLIENT_ID:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Desktop OAuth is not configured",
)
token_payload = {
"client_id": config.GOOGLE_DESKTOP_CLIENT_ID,
"code": body.code,
"code_verifier": body.code_verifier,
"grant_type": "authorization_code",
"redirect_uri": body.redirect_uri,
}
if config.GOOGLE_DESKTOP_CLIENT_SECRET:
token_payload["client_secret"] = config.GOOGLE_DESKTOP_CLIENT_SECRET
async with httpx.AsyncClient(timeout=10) as client:
token_response = await client.post(
"https://oauth2.googleapis.com/token", data=token_payload
)
if token_response.status_code >= 400:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed"
)
token_data = token_response.json()
id_token = token_data.get("id_token")
access_token = token_data.get("access_token")
if not id_token or not access_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed"
)
try:
claims = google_id_token.verify_oauth2_token(
id_token,
google_requests.Request(),
config.GOOGLE_DESKTOP_CLIENT_ID,
)
except Exception as exc:
logger.warning("Desktop Google id_token verification failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Google identity token",
) from exc
user = await resolve_google_user(
user_manager=user_manager,
request=request,
google_access_token=access_token,
claims=claims,
expires_at=(
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
if token_data.get("expires_in")
else None
),
google_refresh_token=token_data.get("refresh_token"),
)
app_access_token = await get_jwt_strategy().write_token(user)
app_refresh_token = await create_refresh_token(user.id)
return RefreshTokenResponse(
access_token=app_access_token,
refresh_token=app_refresh_token,
access_expires_at=access_expires_at(app_access_token),
)

View file

@ -7,8 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.runtime.path_resolver import virtual_path_to_doc
from app.auth.context import AuthContext
from app.db import (
Chunk,
Document,
@ -18,7 +18,6 @@ from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
@ -684,7 +683,6 @@ async def search_document_titles(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Lightweight document title search optimized for mention picker (@mentions).
@ -789,7 +787,6 @@ async def get_document_by_virtual_path(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Resolve a knowledge-base document by its agent-facing virtual path.
The agent renders every document under ``/documents/...`` with a
@ -847,7 +844,6 @@ async def get_documents_status(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Batch status endpoint for documents in a search space.
@ -1071,7 +1067,6 @@ async def get_watched_folders(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
await check_permission(
session,
@ -1113,7 +1108,6 @@ async def get_document_chunks_paginated(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Paginated chunk loading for a document.
Supports both page-based and offset-based access.
@ -1175,7 +1169,6 @@ async def read_document(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific document by ID.
Requires DOCUMENTS_READ permission for the search space.
@ -1230,7 +1223,6 @@ async def update_document(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a document.
Requires DOCUMENTS_UPDATE permission for the search space.
@ -1290,7 +1282,6 @@ async def delete_document(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a document.
Requires DOCUMENTS_DELETE permission for the search space.
@ -1536,7 +1527,6 @@ async def folder_mtime_check(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Pre-upload optimization: check which files need uploading based on mtime.
Returns the subset of relative paths where the file is new or has a
@ -1754,7 +1744,6 @@ async def folder_unlink(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Handle file deletion events from the desktop watcher.
For each relative path, find the matching document and delete it.
@ -1809,7 +1798,6 @@ async def folder_sync_finalize(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Finalize a full folder scan by deleting orphaned documents.
The client sends the complete list of relative paths currently in the

View file

@ -19,7 +19,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
from app.db import Chunk, Document, DocumentType, Permission, get_async_session
from app.routes.reports_routes import (
_FILE_EXTENSIONS,
_MEDIA_TYPES,
@ -50,7 +50,6 @@ async def get_editor_content(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get document content for editing.
@ -182,7 +181,6 @@ async def download_document_markdown(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Download the full document content as a .md file.
Reconstructs markdown from source_markdown or chunks.
@ -337,7 +335,6 @@ async def export_document(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Export a document in the requested format (reuses the report export pipeline)."""
await check_permission(
session,

View file

@ -8,7 +8,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Permission, User, get_async_session
from app.db import Permission, get_async_session
from app.services.export_service import build_export_zip
from app.users import get_auth_context
from app.utils.rbac import check_permission
@ -27,7 +27,6 @@ async def export_knowledge_base(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Export documents as a ZIP of markdown files preserving folder structure."""
await check_permission(
session,

View file

@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import Document, Folder, Permission, User, get_async_session
from app.db import Document, Folder, Permission, get_async_session
from app.schemas import (
BulkDocumentMove,
DocumentMove,
@ -95,7 +95,6 @@ async def list_folders(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
try:
await check_permission(
@ -127,7 +126,6 @@ async def get_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get a single folder. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
@ -158,7 +156,6 @@ async def get_folder_breadcrumb(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
try:
folder = await session.get(Folder, folder_id)
@ -203,7 +200,6 @@ async def stop_watching_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Clear the watched flag from a folder's metadata."""
folder = await session.get(Folder, folder_id)
if not folder:
@ -232,7 +228,6 @@ async def update_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Rename a folder. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
@ -273,7 +268,6 @@ async def move_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
try:
folder = await session.get(Folder, folder_id)
@ -334,7 +328,6 @@ async def reorder_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
try:
folder = await session.get(Folder, folder_id)
@ -376,7 +369,6 @@ async def delete_folder(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
try:
folder = await session.get(Folder, folder_id)
@ -451,7 +443,6 @@ async def move_document(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
result = await session.execute(
@ -498,7 +489,6 @@ async def bulk_move_documents(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try:
if not request.document_ids:

View file

@ -30,7 +30,6 @@ from app.db import (
ExternalChatHealthStatus,
ExternalChatPeerKind,
ExternalChatPlatform,
User,
get_async_session,
)
from app.gateway.accounts import (
@ -979,7 +978,6 @@ async def list_platforms(
async def get_gateway_config(
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, bool | str]:
user = auth.user
if not config.GATEWAY_ENABLED:
return {
"enabled": False,

View file

@ -101,7 +101,6 @@ async def request_pairing_code(
async def bridge_health(
auth: AuthContext = Depends(get_auth_context),
) -> dict[str, Any]:
user = auth.user
_ensure_baileys_enabled()
adapter = WhatsAppBaileysAdapter()
try:

View file

@ -24,7 +24,6 @@ from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
@ -224,6 +223,7 @@ async def _execute_image_generation(
# Fix relative URLs in response data (for the serving endpoint)
from urllib.parse import urlparse
images = response_dict.get("data", [])
provider_base_url = resolved_kwargs.get("api_base")
for image in images:
@ -422,7 +422,6 @@ async def get_image_generation(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Get a specific image generation by ID."""
try:
result = await session.execute(
@ -455,7 +454,6 @@ async def delete_image_generation(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Delete an image generation record."""
try:
result = await session.execute(

View file

@ -13,7 +13,6 @@ from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import LogCreate, LogRead, LogUpdate
@ -29,7 +28,6 @@ async def create_log(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new log entry.
Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated).
@ -141,7 +139,6 @@ async def read_log(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific log by ID.
Requires LOGS_READ permission for the search space.
@ -178,7 +175,6 @@ async def update_log(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a log entry.
Requires LOGS_READ permission (logs are typically updated by system).
@ -222,7 +218,6 @@ async def delete_log(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a log entry.
Requires LOGS_DELETE permission for the search space.
@ -262,7 +257,6 @@ async def get_logs_summary(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a summary of logs for a search space in the last X hours.
Requires LOGS_READ permission for the search space.

View file

@ -325,7 +325,9 @@ async def _assert_connection_access(
@router.get("/global-llm-config-status")
async def global_llm_config_status(auth: AuthContext = Depends(require_session_context)):
async def global_llm_config_status(
auth: AuthContext = Depends(require_session_context),
):
del auth
return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS}

View file

@ -10,7 +10,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Document, DocumentType, Permission, User, get_async_session
from app.db import Document, DocumentType, Permission, get_async_session
from app.schemas import DocumentRead, PaginatedResponse
from app.users import get_auth_context
from app.utils.rbac import check_permission
@ -102,7 +102,6 @@ async def list_notes(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all notes in a search space.
@ -196,7 +195,6 @@ async def delete_note(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a note.

View file

@ -125,7 +125,6 @@ PERMISSION_DESCRIPTIONS = {
async def list_all_permissions(
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all available permissions that can be assigned to roles.
"""
@ -162,7 +161,6 @@ async def create_role(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Create a new custom role in a search space.
Requires ROLES_CREATE permission.
@ -244,7 +242,6 @@ async def list_roles(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all roles in a search space.
Requires ROLES_READ permission.
@ -283,7 +280,6 @@ async def get_role(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific role by ID.
Requires ROLES_READ permission.
@ -329,7 +325,6 @@ async def update_role(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a role.
Requires ROLES_UPDATE permission.
@ -427,7 +422,6 @@ async def delete_role(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a custom role.
Requires ROLES_DELETE permission.
@ -485,7 +479,6 @@ async def list_members(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all members of a search space.
Requires MEMBERS_VIEW permission.
@ -551,7 +544,6 @@ async def update_member_role(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update a member's role.
Requires MEMBERS_MANAGE_ROLES permission.
@ -689,7 +681,6 @@ async def remove_member(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Remove a member from a search space.
Requires MEMBERS_REMOVE permission.
@ -814,7 +805,6 @@ async def list_invites(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all invites for a search space.
Requires MEMBERS_INVITE permission.
@ -854,7 +844,6 @@ async def update_invite(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update an invite.
Requires MEMBERS_INVITE permission.
@ -921,7 +910,6 @@ async def revoke_invite(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Revoke (delete) an invite.
Requires MEMBERS_INVITE permission.

View file

@ -33,7 +33,6 @@ from app.db import (
Report,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import ReportContentRead, ReportContentUpdate, ReportRead
@ -161,7 +160,6 @@ async def _get_report_with_access(
session: AsyncSession,
auth: AuthContext,
) -> Report:
user = auth.user
"""Fetch a report and verify the user belongs to its search space.
Raises HTTPException(404) if not found, HTTPException(403) if no access.

View file

@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import NewChatThread, Permission, User, get_async_session
from app.db import NewChatThread, Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
@ -50,7 +50,6 @@ async def download_sandbox_file(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Download a file from the Daytona sandbox associated with a chat thread."""
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import (

View file

@ -40,7 +40,6 @@ from app.db import (
Permission,
SearchSourceConnector,
SearchSourceConnectorType,
User,
async_session_maker,
get_async_session,
)
@ -286,7 +285,6 @@ async def read_search_source_connectors(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all search source connectors for a search space.
Requires CONNECTORS_READ permission.
@ -330,7 +328,6 @@ async def read_search_source_connector(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific search source connector by ID.
Requires CONNECTORS_READ permission.
@ -565,7 +562,6 @@ async def delete_search_source_connector(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a search source connector and all its associated documents.
@ -2735,7 +2731,6 @@ async def list_mcp_connectors(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
List all MCP connectors for a search space.
@ -2787,7 +2782,6 @@ async def get_mcp_connector(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific MCP connector by ID.
@ -2841,7 +2835,6 @@ async def update_mcp_connector(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Update an MCP connector.
@ -2918,7 +2911,6 @@ async def delete_mcp_connector(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete an MCP connector.
@ -2977,7 +2969,6 @@ async def test_mcp_server_connection(
server_config: dict = Body(...),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Test connection to an MCP server and fetch available tools.
@ -3058,7 +3049,6 @@ async def get_drive_picker_token(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""Return an OAuth access token + client ID for the Google Picker API."""
result = await session.execute(
select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)

View file

@ -279,7 +279,9 @@ async def update_search_space(
) from e
@router.put("/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead)
@router.put(
"/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead
)
async def update_search_space_api_access(
search_space_id: int,
body: SearchSpaceApiAccessUpdate,

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import User, get_async_session
from app.db import get_async_session
from app.services.memory import (
MemoryRead,
MemoryScope,
@ -32,7 +32,6 @@ async def get_team_memory(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
await check_search_space_access(session, auth, search_space_id)
memory_md = await read_memory(
scope=MemoryScope.TEAM,
@ -49,7 +48,6 @@ async def update_team_memory(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await save_memory(
scope=MemoryScope.TEAM,
@ -68,7 +66,6 @@ async def reset_team_memory(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await reset_memory(
scope=MemoryScope.TEAM,

View file

@ -0,0 +1,34 @@
"""Cookie-aware user profile routes."""
from fastapi import APIRouter, Depends, Request
from app.auth.context import AuthContext
from app.schemas import UserRead, UserUpdate
from app.users import (
UserManager,
get_auth_context,
get_user_manager,
require_session_context,
)
router = APIRouter(prefix="/users", tags=["users"])
@router.get("/me", response_model=UserRead)
async def get_current_user_profile(
auth: AuthContext = Depends(get_auth_context),
):
return auth.user
@router.patch("/me", response_model=UserRead)
async def update_current_user_profile(
update: UserUpdate,
request: Request,
auth: AuthContext = Depends(require_session_context),
user_manager: UserManager = Depends(get_user_manager),
):
updated_user = await user_manager.update(
update, auth.user, safe=True, request=request
)
return updated_user

View file

@ -21,7 +21,6 @@ from app.db import (
Permission,
SearchSpace,
SearchSpaceMembership,
User,
VideoPresentation,
get_async_session,
)
@ -93,7 +92,6 @@ async def read_video_presentation(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Get a specific video presentation by ID.
Requires authentication with VIDEO_PRESENTATIONS_READ permission.
@ -137,7 +135,6 @@ async def delete_video_presentation(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Delete a video presentation.
Requires VIDEO_PRESENTATIONS_DELETE permission for the search space.
@ -181,7 +178,6 @@ async def stream_slide_audio(
session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context),
):
user = auth.user
"""
Stream the audio file for a specific slide in a video presentation.
The slide_number is 1-based. Audio path is read from the slides JSONB.

View file

@ -0,0 +1,31 @@
"""Zero sync authentication context routes."""
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import get_async_session
from app.users import get_auth_context
from app.utils.rbac import get_allowed_read_space_ids
router = APIRouter(prefix="/zero", tags=["zero"])
class ZeroContextResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True)
user_id: str = Field(alias="userId")
allowed_space_ids: list[int] = Field(alias="allowedSpaceIds")
@router.get("/context", response_model=ZeroContextResponse)
async def get_zero_context(
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ZeroContextResponse:
allowed_space_ids = await get_allowed_read_space_ids(session, auth)
return ZeroContextResponse(
user_id=str(auth.user.id),
allowed_space_ids=allowed_space_ids,
)

View file

@ -242,9 +242,9 @@ __all__ = [
"SearchSourceConnectorCreate",
"SearchSourceConnectorRead",
"SearchSourceConnectorUpdate",
"SearchSpaceApiAccessUpdate",
# Search space schemas
"SearchSpaceBase",
"SearchSpaceApiAccessUpdate",
"SearchSpaceCreate",
"SearchSpaceRead",
"SearchSpaceUpdate",

View file

@ -6,21 +6,22 @@ from pydantic import BaseModel
class RefreshTokenRequest(BaseModel):
"""Request body for token refresh endpoint."""
refresh_token: str
refresh_token: str | None = None
class RefreshTokenResponse(BaseModel):
"""Response from token refresh endpoint."""
access_token: str
refresh_token: str
refresh_token: str | None = None
token_type: str = "bearer"
access_expires_at: int
class LogoutRequest(BaseModel):
"""Request body for logout endpoint (current device)."""
refresh_token: str
refresh_token: str | None = None
class LogoutResponse(BaseModel):
@ -33,3 +34,19 @@ class LogoutAllResponse(BaseModel):
"""Response from logout all devices endpoint."""
detail: str = "Successfully logged out from all devices"
class SessionResponse(BaseModel):
authenticated: bool = True
access_expires_at: int | None = None
class DesktopSessionRequest(BaseModel):
code: str
code_verifier: str
redirect_uri: str
class DesktopLoginRequest(BaseModel):
email: str
password: str

View file

@ -435,7 +435,6 @@ async def list_snapshots_for_thread(
thread_id: int,
auth: AuthContext,
) -> list[dict]:
user = auth.user
"""List all public snapshots for a thread."""
from app.config import config
@ -482,7 +481,6 @@ async def list_snapshots_for_search_space(
search_space_id: int,
auth: AuthContext,
) -> list[dict]:
user = auth.user
"""List all public snapshots for a search space."""
from app.config import config
@ -540,7 +538,6 @@ async def delete_snapshot(
snapshot_id: int,
auth: AuthContext,
) -> bool:
user = auth.user
"""Delete a specific snapshot. Only thread owner can delete."""
# Get snapshot with thread
result = await session.execute(

View file

@ -0,0 +1,34 @@
"""Celery task for pruning expired refresh-token rows."""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from sqlalchemy import delete, or_
from app.celery_app import celery_app
from app.config import config
from app.db import RefreshToken, async_session_maker
@celery_app.task(name="purge_refresh_tokens")
def purge_refresh_tokens() -> int:
return asyncio.run(_purge_refresh_tokens())
async def _purge_refresh_tokens() -> int:
now = datetime.now(UTC)
revoked_cutoff = now - timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
async with async_session_maker() as session:
result = await session.execute(
delete(RefreshToken).where(
or_(
RefreshToken.expires_at < now,
RefreshToken.revoked_at < revoked_cutoff,
)
)
)
await session.commit()
return result.rowcount or 0

View file

@ -3,6 +3,7 @@ import uuid
from datetime import UTC, datetime
import httpx
import jwt
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
@ -12,11 +13,12 @@ from fastapi_users.authentication import (
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from fastapi_users.jwt import generate_jwt
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.auth.session_cookies import access_expires_at, write_session
from app.config import config
from app.db import (
Prompt,
@ -36,12 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__)
class BearerResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
SECRET = config.SECRET_KEY
@ -230,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
yield UserManager(user_db)
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
async def write_token(self, user: models.UP) -> str:
data = {
"sub": str(user.id),
"aud": self.token_audience,
"iat": int(datetime.now(UTC).timestamp()),
}
return generate_jwt(
data,
self.encode_key,
self.lifetime_seconds,
algorithm=self.algorithm,
)
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(
return IatJWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
@ -260,9 +271,6 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
import jwt
# Decode JWT to get user_id for refresh token creation
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
@ -271,24 +279,26 @@ class CustomBearerTransport(BearerTransport):
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create session",
) from e
if config.AUTH_TYPE == "GOOGLE":
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
response = RedirectResponse(
f"{config.NEXT_FRONTEND_URL}/dashboard",
status_code=302,
)
return RedirectResponse(redirect_url, status_code=302)
else:
return JSONResponse(bearer_response.model_dump())
response = JSONResponse(
{
"authenticated": True,
"access_expires_at": access_expires_at(token),
}
)
write_session(response, token, refresh_token)
return response
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
@ -303,6 +313,22 @@ auth_backend = AuthenticationBackend(
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
def _token_meets_epoch(token: str) -> bool:
min_issued_at = config.MIN_ISSUED_AT
if min_issued_at <= 0:
return True
try:
payload = jwt.decode(
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
)
except jwt.PyJWTError:
return False
issued_at = payload.get("iat")
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
async def get_auth_context(
request: Request,
session: AsyncSession = Depends(get_async_session),
@ -315,38 +341,42 @@ async def get_auth_context(
receives the full SurfSense principal instead of a bare User.
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
if auth_header:
scheme, _, credential = auth_header.partition(" ")
is_bearer = scheme.lower() == "bearer" and bool(credential)
token = credential if is_bearer else auth_header.strip()
scheme, _, token = auth_header.partition(" ")
if scheme.lower() != "bearer" or not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
if token.startswith(PAT_PREFIX):
pat = await resolve_pat(session, token)
if pat and pat.user and pat.user.is_active:
maybe_touch_last_used(pat)
return AuthContext.pat_auth(pat.user, pat)
if token.startswith(PAT_PREFIX):
pat = await resolve_pat(session, token)
if pat and pat.user and pat.user.is_active:
maybe_touch_last_used(pat)
return AuthContext.pat_auth(pat.user, pat)
if is_bearer and _token_meets_epoch(token):
try:
user = await get_jwt_strategy().read_token(token, user_manager)
except Exception:
logger.exception("Failed to read bearer access token")
user = None
try:
user = await get_jwt_strategy().read_token(token, user_manager)
except Exception:
logger.exception("Failed to read access token")
user = None
if user and user.is_active:
return AuthContext.session(user)
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
if cookie_token and _token_meets_epoch(cookie_token):
try:
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
except Exception:
logger.exception("Failed to read session cookie access token")
user = None
return AuthContext.session(user)
if user and user.is_active:
return AuthContext.session(user)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
async def allow_any_principal(
@ -371,6 +401,3 @@ async def require_session_context(
detail="This action requires an interactive session",
)
return auth
current_optional_user = fastapi_users.current_user(active=True, optional=True)

View file

@ -23,11 +23,15 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
def _render_inline_content(
content: list[dict[str, Any]] | None,
inherited_styles: dict[str, Any] | None = None,
) -> str:
"""Convert BlockNote inline content array to a markdown string."""
if not content:
return ""
inherited_styles = inherited_styles or {}
parts: list[str] = []
for item in content:
if not isinstance(item, dict):
@ -37,7 +41,10 @@ def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
if item_type == "text":
text = item.get("text", "")
styles: dict[str, Any] = item.get("styles", {})
styles: dict[str, Any] = {
**inherited_styles,
**item.get("styles", {}),
}
# Apply inline styles (order: code first so nested marks don't break it)
if styles.get("code"):
@ -56,7 +63,11 @@ def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
elif item_type == "link":
href = item.get("href", "")
link_content = item.get("content", [])
link_text = _render_inline_content(link_content) if link_content else href
link_text = (
_render_inline_content(link_content, inherited_styles)
if link_content
else href
)
parts.append(f"[{link_text}]({href})")
else:
@ -89,6 +100,7 @@ def _render_block(
"""
block_type = block.get("type", "paragraph")
props: dict[str, Any] = block.get("props", {})
styles: dict[str, Any] = block.get("styles", {})
content = block.get("content")
children: list[dict[str, Any]] = block.get("children", [])
prefix = " " * indent # 2-space indent per nesting level
@ -98,17 +110,17 @@ def _render_block(
# --- Block type handlers ---
if block_type == "paragraph":
text = _render_inline_content(content) if content else ""
text = _render_inline_content(content, styles) if content else ""
lines.append(f"{prefix}{text}")
elif block_type == "heading":
level = props.get("level", 1)
hashes = "#" * min(max(level, 1), 6)
text = _render_inline_content(content) if content else ""
text = _render_inline_content(content, styles) if content else ""
lines.append(f"{prefix}{hashes} {text}")
elif block_type == "bulletListItem":
text = _render_inline_content(content) if content else ""
text = _render_inline_content(content, styles) if content else ""
lines.append(f"{prefix}- {text}")
elif block_type == "numberedListItem":
@ -118,13 +130,13 @@ def _render_block(
numbered_list_counter = int(start)
else:
numbered_list_counter += 1
text = _render_inline_content(content) if content else ""
text = _render_inline_content(content, styles) if content else ""
lines.append(f"{prefix}{numbered_list_counter}. {text}")
elif block_type == "checkListItem":
checked = props.get("checked", False)
marker = "[x]" if checked else "[ ]"
text = _render_inline_content(content) if content else ""
text = _render_inline_content(content, styles) if content else ""
lines.append(f"{prefix}- {marker} {text}")
elif block_type == "codeBlock":

View file

@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
PAT_PREFIX = "ss_pat_"
PAT_TOKEN_BYTES = 32
LAST_USED_THROTTLE = timedelta(minutes=10)
_last_used_tasks: set[asyncio.Task[None]] = set()
def generate_pat() -> str:
@ -70,4 +71,6 @@ def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
return
asyncio.create_task(_touch_last_used(pat.id))
task = asyncio.create_task(_touch_last_used(pat.id))
_last_used_tasks.add(task)
task.add_done_callback(_last_used_tasks.discard)

View file

@ -80,6 +80,28 @@ async def get_user_permissions(
return []
async def get_allowed_read_space_ids(
session: AsyncSession,
auth: AuthContext,
) -> list[int]:
"""Return search spaces the principal may read through sync transports.
This mirrors the basic REST search-space access rule: membership is required,
and PAT principals are additionally constrained by the per-space API gate.
"""
stmt = (
select(SearchSpaceMembership.search_space_id)
.join(SearchSpace, SearchSpace.id == SearchSpaceMembership.search_space_id)
.filter(SearchSpaceMembership.user_id == auth.user.id)
.order_by(SearchSpaceMembership.search_space_id)
)
if auth.is_gated:
stmt = stmt.filter(SearchSpace.api_access_enabled == True) # noqa: E712
result = await session.execute(stmt)
return list(result.scalars().all())
async def _enforce_api_access_gate(
session: AsyncSession,
auth: AuthContext,

View file

@ -4,6 +4,7 @@ import hashlib
import logging
import secrets
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
@ -14,6 +15,13 @@ from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RefreshRotationResult:
user_id: uuid.UUID
refresh_token: str | None
access_only: bool = False
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
@ -27,6 +35,7 @@ def hash_token(token: str) -> str:
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
absolute_expiry: datetime | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
@ -40,8 +49,14 @@ async def create_refresh_token(
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
now = datetime.now(UTC)
if absolute_expiry is None:
absolute_expiry = now + timedelta(
seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS
)
expires_at = min(
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
absolute_expiry,
)
if family_id is None:
@ -53,6 +68,7 @@ async def create_refresh_token(
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
absolute_expiry=absolute_expiry,
)
session.add(refresh_token)
await session.commit()
@ -61,15 +77,7 @@ async def create_refresh_token(
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
"""Validate an active refresh token without rotating it."""
token_hash = hash_token(token)
async with async_session_maker() as session:
@ -81,43 +89,87 @@ async def validate_refresh_token(token: str) -> RefreshToken | None:
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
now = datetime.now(UTC)
if (
refresh_token.revoked_at is not None
or now >= refresh_token.expires_at
or (
refresh_token.absolute_expiry is not None
and now >= refresh_token.absolute_expiry
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
):
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
"""Atomically rotate a refresh token with access-only grace."""
token_hash = hash_token(token)
now = datetime.now(UTC)
grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
return await create_refresh_token(old_token.user_id, old_token.family_id)
async with async_session_maker() as session:
async with session.begin():
result = await session.execute(
select(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.with_for_update()
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
user_id = refresh_token.user_id
if refresh_token.revoked_at is not None:
if (
now - refresh_token.revoked_at <= grace_window
and now < refresh_token.expires_at
):
return RefreshRotationResult(
user_id=user_id,
refresh_token=None,
access_only=True,
)
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(revoked_at=now, expires_at=now)
)
logger.warning(f"Token reuse detected for user {user_id}")
return None
if now >= refresh_token.expires_at:
return None
family_cap = refresh_token.absolute_expiry or (
now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
)
if now >= family_cap:
return None
new_plaintext = generate_refresh_token()
child = RefreshToken(
user_id=user_id,
token_hash=hash_token(new_plaintext),
expires_at=min(
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
family_cap,
),
family_id=refresh_token.family_id,
absolute_expiry=family_cap,
)
session.add(child)
refresh_token.revoked_at = now
refresh_token.absolute_expiry = family_cap
return RefreshRotationResult(
user_id=user_id,
refresh_token=new_plaintext,
access_only=False,
)
async def revoke_refresh_token(token: str) -> bool:
@ -131,12 +183,13 @@ async def revoke_refresh_token(token: str) -> bool:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
now = datetime.now(UTC)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
.values(revoked_at=now, expires_at=now)
)
await session.commit()
return result.rowcount > 0
@ -144,10 +197,11 @@ async def revoke_refresh_token(token: str) -> bool:
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
now = datetime.now(UTC)
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
.values(revoked_at=now, expires_at=now)
)
await session.commit()

View file

@ -52,6 +52,16 @@ AUTOMATION_RUN_COLS = [
"created_at",
]
AUTOMATION_COLS = [
"id",
"search_space_id",
]
NEW_CHAT_THREAD_COLS = [
"id",
"search_space_id",
]
# Enough to drive the lifecycle UI by push: status, the reviewable brief, and
# its version. The bulky source_content and transcript are deliberately excluded
# and fetched over REST when a gate opens.
@ -73,10 +83,12 @@ ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = {
"documents": DOCUMENT_COLS,
"folders": None,
"search_source_connectors": None,
"new_chat_threads": NEW_CHAT_THREAD_COLS,
"new_chat_messages": None,
"chat_comments": None,
"chat_session_state": None,
"user": USER_COLS,
"automations": AUTOMATION_COLS,
"automation_runs": AUTOMATION_RUN_COLS,
"podcasts": PODCAST_COLS,
}

View file

@ -0,0 +1,69 @@
"""One-shot cutover helper to revoke every refresh token.
Run with --yes during the auth-hardening cutover, alongside setting
MIN_ISSUED_AT to the deploy epoch.
"""
from __future__ import annotations
import argparse
import asyncio
from sqlalchemy import text
from app.db import async_session_maker
async def _count_active_tokens() -> int:
async with async_session_maker() as session:
result = await session.execute(
text(
"""
SELECT count(*)
FROM refresh_tokens
WHERE revoked_at IS NULL
AND expires_at > NOW()
"""
)
)
return int(result.scalar_one())
async def _revoke_all_tokens() -> int:
async with async_session_maker() as session:
result = await session.execute(
text(
"""
UPDATE refresh_tokens
SET revoked_at = NOW(),
expires_at = NOW()
WHERE revoked_at IS NULL
OR expires_at > NOW()
"""
)
)
await session.commit()
return int(result.rowcount or 0)
async def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--yes",
action="store_true",
help="Actually revoke tokens. Without this flag the command is a dry run.",
)
args = parser.parse_args()
active_count = await _count_active_tokens()
if not args.yes:
print(f"Dry run: {active_count} active refresh token(s) would be revoked.")
print("Re-run with --yes during the auth-hardening cutover to revoke them.")
return
updated_count = await _revoke_all_tokens()
print(f"Revoked {updated_count} refresh token row(s).")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -8,7 +8,6 @@ webhook fulfillment (idempotent), and the reconciliation fallback.
from __future__ import annotations
from types import SimpleNamespace
from urllib.parse import parse_qs, urlparse
import asyncpg
import httpx
@ -63,18 +62,13 @@ def _extract_access_token(response: httpx.Response) -> str | None:
if response.status_code == 200:
return response.json()["access_token"]
if response.status_code == 302:
location = response.headers.get("location", "")
return parse_qs(urlparse(location).query).get("token", [None])[0]
return None
async def _authenticate_test_user(client: httpx.AsyncClient) -> str:
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
"/auth/desktop/login",
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
)
token = _extract_access_token(response)
if token:
@ -89,9 +83,8 @@ async def _authenticate_test_user(client: httpx.AsyncClient) -> str:
)
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
"/auth/desktop/login",
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
)
token = _extract_access_token(response)
assert token, f"Login failed ({response.status_code}): {response.text}"

View file

@ -0,0 +1,90 @@
from __future__ import annotations
from types import SimpleNamespace
from fastapi import Request, Response
from app.auth.session_cookies import TransportMode, issue, read_refresh
from app.config import config
def _request_with_refresh_cookie(token: str) -> Request:
scope = {
"type": "http",
"method": "POST",
"path": "/auth/jwt/refresh",
"headers": [(b"cookie", f"{config.REFRESH_COOKIE_NAME}={token}".encode())],
"scheme": "https",
"server": ("testserver", 443),
}
return Request(scope)
def test_cookie_transport_sets_cookies_without_body_tokens():
response = Response()
body = issue(
response,
TransportMode.COOKIE,
access="access-token",
refresh="refresh-token",
access_expires_at=123,
)
assert "access_token" not in body
assert "refresh_token" not in body
assert body == {"authenticated": True, "access_expires_at": 123}
set_cookie_headers = response.headers.getlist("set-cookie")
assert any(config.SESSION_COOKIE_NAME in header for header in set_cookie_headers)
assert any(config.REFRESH_COOKIE_NAME in header for header in set_cookie_headers)
def test_cookie_transport_re_stamps_access_without_refresh_body_or_cookie():
response = Response()
body = issue(
response,
TransportMode.COOKIE,
access="access-token",
refresh=None,
access_expires_at=123,
)
assert "access_token" not in body
assert "refresh_token" not in body
set_cookie_headers = response.headers.getlist("set-cookie")
assert any(config.SESSION_COOKIE_NAME in header for header in set_cookie_headers)
assert not any(
config.REFRESH_COOKIE_NAME in header for header in set_cookie_headers
)
def test_header_transport_returns_body_tokens_without_cookies():
response = Response()
body = issue(
response,
TransportMode.HEADER,
access="access-token",
refresh="refresh-token",
access_expires_at=123,
)
assert body == {
"access_token": "access-token",
"refresh_token": "refresh-token",
"token_type": "bearer",
"access_expires_at": 123,
}
assert "set-cookie" not in response.headers
def test_read_refresh_cookie_source_wins_over_body_source():
request = _request_with_refresh_cookie("cookie-token")
refresh, mode = read_refresh(request, SimpleNamespace(refresh_token="body-token"))
assert refresh == "cookie-token"
assert mode is TransportMode.COOKIE

View file

@ -0,0 +1,85 @@
"""Regression tests for Zero's backend-computed authorization context."""
from __future__ import annotations
import pytest
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import PersonalAccessToken, SearchSpace, User
from app.routes.search_spaces_routes import create_default_roles_and_membership
from app.utils.rbac import check_search_space_access, get_allowed_read_space_ids
pytestmark = pytest.mark.integration
def _pat_auth(user: User) -> AuthContext:
pat = PersonalAccessToken(
user_id=user.id,
user=user,
token_hash="1" * 64,
token_prefix="ss_pat_zero",
label="Zero PAT",
)
return AuthContext.pat_auth(user, pat)
async def _space_with_membership(
db_session: AsyncSession,
user: User,
*,
api_access_enabled: bool,
) -> SearchSpace:
space = SearchSpace(
name="Zero Authz Space",
user_id=user.id,
api_access_enabled=api_access_enabled,
)
db_session.add(space)
await db_session.flush()
await create_default_roles_and_membership(db_session, space.id, user.id)
await db_session.flush()
return space
async def test_zero_read_set_matches_session_search_space_access(
db_session: AsyncSession,
db_user: User,
db_search_space: SearchSpace,
):
disabled_space = await _space_with_membership(
db_session,
db_user,
api_access_enabled=False,
)
session_auth = AuthContext.session(db_user)
allowed_ids = set(await get_allowed_read_space_ids(db_session, session_auth))
for space in (db_search_space, disabled_space):
membership = await check_search_space_access(db_session, session_auth, space.id)
assert membership.search_space_id in allowed_ids
async def test_zero_read_set_applies_pat_api_access_gate(
db_session: AsyncSession,
db_user: User,
db_search_space: SearchSpace,
):
db_search_space.api_access_enabled = True
disabled_space = await _space_with_membership(
db_session,
db_user,
api_access_enabled=False,
)
await db_session.flush()
pat_auth = _pat_auth(db_user)
allowed_ids = set(await get_allowed_read_space_ids(db_session, pat_auth))
assert db_search_space.id in allowed_ids
assert disabled_space.id not in allowed_ids
with pytest.raises(HTTPException) as exc_info:
await check_search_space_access(db_session, pat_auth, disabled_space.id)
assert exc_info.value.status_code == 403

View file

@ -40,7 +40,9 @@ async def cleanup_supervisors():
async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook")
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
)
await byo_long_poll.start_byo_long_poll_supervisors()
@ -53,7 +55,9 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
)
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([])
monkeypatch.setattr(
@ -75,7 +79,9 @@ async def test_start_byo_long_poll_spawns_one_supervisor_per_account(
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
)
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult(accounts)
@ -125,7 +131,9 @@ async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
)
session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
monkeypatch.setattr(

View file

@ -450,7 +450,9 @@ class TestRevertTurnDispatch:
thread_id=1,
chat_turn_id="ct-mixed-all",
session=session,
auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id
auth=AuthContext.session(
_FakeUser()
), # only id=7 has a different user_id
)
assert response.total == len(rows) == 6

View file

@ -32,7 +32,9 @@ def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch):
_CapturedChatLiteLLM.calls = []
async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace:
async def _fake_search_space(
_session: Any, _search_space_id: int
) -> SimpleNamespace:
return SimpleNamespace(id=42, user_id="user-1")
monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space)

View file

@ -32,11 +32,7 @@ CONNECTOR_LISTERS = [
def _python_files() -> list[Path]:
return [
path
for path in APP_ROOT.rglob("*.py")
if "__pycache__" not in path.parts
]
return [path for path in APP_ROOT.rglob("*.py") if "__pycache__" not in path.parts]
def test_current_active_user_is_removed_from_app_tree() -> None:

View file

@ -0,0 +1,22 @@
"""Static guards for Zero authorization wiring."""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
REPO_ROOT = Path(__file__).resolve().parents[3]
WEB_ROOT = REPO_ROOT / "surfsense_web"
def test_zero_query_route_uses_authoritative_backend_context() -> None:
route = WEB_ROOT / "app/api/zero/query/route.ts"
text = route.read_text()
assert "/zero/context" in text
assert "/users/me" not in text
assert "userID: auth.ctx.userId" in text
assert "handleQueryRequest({" in text

View file

@ -290,4 +290,4 @@ class TestExtractTextContent:
def test_boolean_returns_empty_string(self):
from app.utils.content_utils import extract_text_content
assert extract_text_content(True) == ""
assert extract_text_content(True) == ""

View file

@ -16,9 +16,8 @@ TEST_PASSWORD = "testpassword123"
async def get_auth_token(client: httpx.AsyncClient) -> str:
"""Log in and return a Bearer JWT token, registering the user first if needed."""
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
"/auth/desktop/login",
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
)
if response.status_code == 200:
return response.json()["access_token"]
@ -32,9 +31,8 @@ async def get_auth_token(client: httpx.AsyncClient) -> str:
)
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
"/auth/desktop/login",
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
)
assert response.status_code == 200, (
f"Login after registration failed ({response.status_code}): {response.text}"

View file

@ -1,10 +0,0 @@
# Electron-specific build-time configuration.
# Set before running pnpm dist:mac / dist:win / dist:linux.
# The hosted web frontend URL. Used to intercept OAuth redirects and keep them
# inside the desktop app. Set to your production frontend domain.
HOSTED_FRONTEND_URL=https://surfsense.com
# PostHog analytics (leave empty to disable)
POSTHOG_KEY=
POSTHOG_HOST=https://assets.surfsense.com

View file

@ -3,7 +3,15 @@
# The hosted web frontend URL. Used to intercept OAuth redirects and keep them
# inside the desktop app. Set to your production frontend domain.
HOSTED_FRONTEND_URL=https://surfsense.com
HOSTED_FRONTEND_URL=http://localhost:3000
# The backend API URL used by desktop auth and refresh flows.
HOSTED_BACKEND_URL=http://localhost:8000
# Public Google OAuth Desktop app client ID. Required for packaged desktop
# Google login using loopback + PKCE. This is safe to ship in the desktop app;
# the PKCE code verifier, not a client secret, protects the token exchange.
GOOGLE_DESKTOP_CLIENT_ID=your_google_desktop_client_id.apps.googleusercontent.com
# Runtime override for the above (read at app start, no rebuild required).
# Useful for self-hosters whose backend NEXT_FRONTEND_URL differs from the

View file

@ -28,7 +28,7 @@
"@types/node": "^25.5.0",
"concurrently": "^9.2.1",
"dotenv": "^17.3.1",
"electron": "^41.0.2",
"electron": "^42.4.0",
"electron-builder": "^26.8.1",
"esbuild": "^0.27.4",
"typescript": "^5.9.3",

View file

@ -46,8 +46,8 @@ importers:
specifier: ^17.3.1
version: 17.3.1
electron:
specifier: ^41.0.2
version: 41.0.2
specifier: ^42.4.0
version: 42.4.0
electron-builder:
specifier: ^26.8.1
version: 26.8.1(electron-builder-squirrel-windows@26.8.1)
@ -70,6 +70,10 @@ packages:
resolution: {integrity: sha512-0cp4PsWQ/9avqTVMCtZ+GirikIA36ikvjtHweU4/j8yLtgObI0+JUPhYFScgwlteveGB1rt3Cm8UhN04XayDig==}
engines: {node: '>= 8.9.0'}
'@electron-internal/extract-zip@1.0.3':
resolution: {integrity: sha512-OjKpjB7gohtEjZiq6nDx1egqjZJhGPN1iFOIED+NFhB/MMkXw/XRcHjh1DGXKT5z2W9eW7Jy2UKU3gpjvusFTQ==}
engines: {node: '>=22.12.0'}
'@electron/asar@3.4.1':
resolution: {integrity: sha512-i4/rNPRS84t0vSRa2HorerGRXWyF4vThfHesw0dmcWHp+cspK743UanA0suA5Q5y8kzY2y6YKrvbIUn69BCAiA==}
engines: {node: '>=10.12.0'}
@ -79,14 +83,14 @@ packages:
resolution: {integrity: sha512-zx0EIq78WlY/lBb1uXlziZmDZI4ubcCXIMJ4uGjXzZW0nS19TjSPeXPAjzzTmKQlJUZm0SbmZhPKP7tuQ1SsEw==}
hasBin: true
'@electron/get@2.0.3':
resolution: {integrity: sha512-Qkzpg2s9GnVV2I2BjRksUi43U5e6+zaQMcjoJy0C+C5oxaKl+fmckGDQFtRpZpZV0NQekuZZ+tGz7EA9TVnQtQ==}
engines: {node: '>=12'}
'@electron/get@3.1.0':
resolution: {integrity: sha512-F+nKc0xW+kVbBRhFzaMgPy3KwmuNTYX1fx6+FxxoSnNgwYX6LD7AKBTWkU0MQ6IBoe7dz069CNkR673sPAgkCQ==}
engines: {node: '>=14'}
'@electron/get@5.0.0':
resolution: {integrity: sha512-pjoBpru1KdEtcExBnuHAP1cAc/5faoedw0hzJkL3o4/IJp7HNF1+fbrdxT3gMYRX2oJfvnA/WXeCTVQpYYxyJA==}
engines: {node: '>=22.12.0'}
'@electron/notarize@2.5.0':
resolution: {integrity: sha512-jNT8nwH1f9X5GEITXaQ8IF/KdskvIkOFfB2CvwumsveVidzpSc+mvhhTMdAGSYF3O+Nq49lJ7y+ssODRXu06+A==}
engines: {node: '>= 10.0.0'}
@ -346,8 +350,8 @@ packages:
'@types/ms@2.1.0':
resolution: {integrity: sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==}
'@types/node@24.12.0':
resolution: {integrity: sha512-GYDxsZi3ChgmckRT9HPU0WEhKLP08ev/Yfcq2AstjrDASOYCSXeyjDsHg4v5t4jOj7cyDX3vmprafKlWIG9MXQ==}
'@types/node@24.13.2':
resolution: {integrity: sha512-fRa09kZTgu8o71KFcDjUFuc7F+dEbZYZmkI0mg5YBTRs0yMKjYHsq/c0urDKeDb+D5qVgXOdFcuu+DZPKOITwA==}
'@types/node@25.5.0':
resolution: {integrity: sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==}
@ -361,9 +365,6 @@ packages:
'@types/verror@1.10.11':
resolution: {integrity: sha512-RlDm9K7+o5stv0Co8i8ZRGxDbrTxhJtgjqjFyVh/tXQyl/rYtTKlnTvZ88oSTeYREWurwx20Js4kTuKCsFkUtg==}
'@types/yauzl@2.10.3':
resolution: {integrity: sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==}
'@xmldom/xmldom@0.8.11':
resolution: {integrity: sha512-cQzWCtO6C8TQiYl1ruKNn2U6Ao4o4WBBcbL61yJl84x+j5sOWWFU9X7DpND8XZG3daDppSsigMdfAIl2upQBRw==}
engines: {node: '>=10.0.0'}
@ -483,9 +484,6 @@ packages:
resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==}
engines: {node: 18 || 20 || >=22}
buffer-crc32@0.2.13:
resolution: {integrity: sha512-VO9Ht/+p3SN7SKWqcrgEzjGbRSJYTx+Q1pTQC0wrWqHx0vpJraQ6GtHx8tvcg1rlK1byhU5gccxgOgj7B0TDkQ==}
buffer-from@1.1.2:
resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==}
@ -714,9 +712,9 @@ packages:
resolution: {integrity: sha512-bO3y10YikuUwUuDUQRM4KfwNkKhnpVO7IPdbsrejwN9/AABJzzTQ4GeHwyzNSrVO+tEH3/Np255a3sVZpZDjvg==}
engines: {node: '>=8.0.0'}
electron@41.0.2:
resolution: {integrity: sha512-raotm/aO8kOs1jD8SI8ssJ7EKciQOY295AOOprl1TxW7B0At8m5Ae7qNU1xdMxofiHMR8cNEGi9PKD3U+yT/mA==}
engines: {node: '>= 12.20.55'}
electron@42.4.0:
resolution: {integrity: sha512-OXXqh9LD9KxXPv2Fe25EfU9N9AvWTuV6V81sfhQaNvTAXCd9ONA+Q4OWvMe+CmYD6xIwjFxGGtG/ZphDYYC5OQ==}
engines: {node: '>= 22.12.0'}
hasBin: true
emoji-regex@8.0.0:
@ -777,11 +775,6 @@ packages:
exponential-backoff@3.1.3:
resolution: {integrity: sha512-ZgEeZXj30q+I0EN+CbSSpIyPaJ5HVQD18Z1m+u1FXbAeT94mr1zw50q4q6jiiC447Nl/YTcIYSAftiGqetwXCA==}
extract-zip@2.0.1:
resolution: {integrity: sha512-GDhU9ntwuKyGXdZBUgTIe+vXnWj0fppUEtMDL0+idd5Sta8TGpHssn/eusA9mrPr9qNDym6SxAYZjNvCn/9RBg==}
engines: {node: '>= 10.17.0'}
hasBin: true
extsprintf@1.4.1:
resolution: {integrity: sha512-Wrk35e8ydCKDj/ArClo1VrPVmN8zph5V4AtHwIuHhvMXsKf73UT3BOD+azBIW+3wOJ4FhEH7zyaJCFvChjYvMA==}
engines: {'0': node >=0.6.0}
@ -795,9 +788,6 @@ packages:
fast-uri@3.1.0:
resolution: {integrity: sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==}
fd-slicer@1.1.0:
resolution: {integrity: sha512-cE1qsB/VwyQozZ+q1dGxR8LBYNZeofhEdUNGSMbQD3Gw2lAzX9Zb3uIU6Ebc/Fmyjo9AWWfnn0AUCHqtevs/8g==}
fdir@6.5.0:
resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==}
engines: {node: '>=12.0.0'}
@ -838,6 +828,10 @@ packages:
resolution: {integrity: sha512-CTXd6rk/M3/ULNQj8FBqBWHYBVYybQ3VPBw0xGKFe3tuH7ytT6ACnvzpIQ3UZtB8yvUKC2cXn1a+x+5EVQLovA==}
engines: {node: '>=14.14'}
fs-extra@11.3.5:
resolution: {integrity: sha512-eKpRKAovdpZtR1WopLHxlBWvAgPny3c4gX1G5Jhwmmw4XJj0ifSD5qB5TOo8hmA0wlRKDAOAhEE1yVPgs6Fgcg==}
engines: {node: '>=14.14'}
fs-extra@7.0.1:
resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==}
engines: {node: '>=6 <7 || >=8'}
@ -1045,6 +1039,9 @@ packages:
jsonfile@6.2.0:
resolution: {integrity: sha512-FGuPw30AdOIUTRMC2OMRtQV+jkVj2cfPqSeWXv1NEAJ1qZ5zb1X6z1mFhbfOB/iy3ssJCD+3KuZ8r8C3uVFlAg==}
jsonfile@6.2.1:
resolution: {integrity: sha512-zwOTdL3rFQ/lRdBnntKVOX6k5cKJwEc1HdilT71BWEu7J41gXIB2MRp+vxduPSwZJPWBxEzv4yH1wYLJGUHX4Q==}
keyv@4.5.4:
resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==}
@ -1261,9 +1258,6 @@ packages:
resolution: {integrity: sha512-eRWB5LBz7PpDu4PUlwT0PhnQfTQJlDDdPa35urV4Osrm0t0AqQFGn+UIkU3klZvwJ8KPO3VbBFsXquA6p6kqZw==}
engines: {node: '>=12', npm: '>=6'}
pend@1.2.0:
resolution: {integrity: sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==}
picocolors@1.1.1:
resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==}
@ -1397,6 +1391,11 @@ packages:
engines: {node: '>=10'}
hasBin: true
semver@7.8.5:
resolution: {integrity: sha512-Y7/KDsb8LjooZpwaqGyulO6DQlksgCncchHGk+sZIY4SBvUocMBEFH5Ur1fI4dV+Jvl0w6cjvucaIi40puRioA==}
engines: {node: '>=10'}
hasBin: true
serialize-error@7.0.1:
resolution: {integrity: sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==}
engines: {node: '>=10'}
@ -1554,12 +1553,13 @@ packages:
resolution: {integrity: sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==}
engines: {node: '>=18'}
undici-types@7.16.0:
resolution: {integrity: sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==}
undici-types@7.18.2:
resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==}
undici@7.28.0:
resolution: {integrity: sha512-cRZYrTDwWznlnRiPjggAGxZXanty6M8RV1ff8Wm4LWXBp7/IG8v5DnOm74DtUBp9OONpK75YlPnIjQqX0dBDtA==}
engines: {node: '>=20.18.1'}
unique-filename@4.0.0:
resolution: {integrity: sha512-XSnEewXmQ+veP7xX2dS5Q4yZAvO40cBN2MWkJ7D/6sW4Dg6wYBNwM1Vrnz1FhH5AdeLIlUXRI9e28z1YZi71NQ==}
engines: {node: ^18.17.0 || >=20.5.0}
@ -1644,9 +1644,6 @@ packages:
resolution: {integrity: sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==}
engines: {node: '>=12'}
yauzl@2.10.0:
resolution: {integrity: sha512-p4a9I6X6nu6IhoGmBqAcbJy1mlC4j27vEPZX9F4L4/vZT3Lyq1VkFHw/V/PUcB9Buo+DG3iHkT0x3Qya58zc3g==}
yocto-queue@0.1.0:
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
engines: {node: '>=10'}
@ -1660,6 +1657,8 @@ snapshots:
ajv: 6.14.0
ajv-keywords: 3.5.2(ajv@6.14.0)
'@electron-internal/extract-zip@1.0.3': {}
'@electron/asar@3.4.1':
dependencies:
commander: 5.1.0
@ -1672,7 +1671,7 @@ snapshots:
fs-extra: 9.1.0
minimist: 1.2.8
'@electron/get@2.0.3':
'@electron/get@3.1.0':
dependencies:
debug: 4.4.3
env-paths: 2.2.1
@ -1686,17 +1685,16 @@ snapshots:
transitivePeerDependencies:
- supports-color
'@electron/get@3.1.0':
'@electron/get@5.0.0':
dependencies:
debug: 4.4.3
env-paths: 2.2.1
fs-extra: 8.1.0
got: 11.8.6
env-paths: 3.0.0
graceful-fs: 4.2.11
progress: 2.0.3
semver: 6.3.1
semver: 7.8.5
sumchecker: 3.0.1
optionalDependencies:
global-agent: 3.0.0
undici: 7.28.0
transitivePeerDependencies:
- supports-color
@ -1753,7 +1751,7 @@ snapshots:
dependencies:
cross-dirname: 0.1.0
debug: 4.4.3
fs-extra: 11.3.4
fs-extra: 11.3.5
minimist: 1.2.8
postject: 1.0.0-alpha.6
transitivePeerDependencies:
@ -1930,9 +1928,9 @@ snapshots:
'@types/ms@2.1.0': {}
'@types/node@24.12.0':
'@types/node@24.13.2':
dependencies:
undici-types: 7.16.0
undici-types: 7.18.2
'@types/node@25.5.0':
dependencies:
@ -1951,11 +1949,6 @@ snapshots:
'@types/verror@1.10.11':
optional: true
'@types/yauzl@2.10.3':
dependencies:
'@types/node': 25.5.0
optional: true
'@xmldom/xmldom@0.8.11': {}
abbrev@3.0.1: {}
@ -2100,8 +2093,6 @@ snapshots:
dependencies:
balanced-match: 4.0.4
buffer-crc32@0.2.13: {}
buffer-from@1.1.2: {}
buffer@5.7.1:
@ -2428,11 +2419,11 @@ snapshots:
transitivePeerDependencies:
- supports-color
electron@41.0.2:
electron@42.4.0:
dependencies:
'@electron/get': 2.0.3
'@types/node': 24.12.0
extract-zip: 2.0.1
'@electron-internal/extract-zip': 1.0.3
'@electron/get': 5.0.0
'@types/node': 24.13.2
transitivePeerDependencies:
- supports-color
@ -2509,16 +2500,6 @@ snapshots:
exponential-backoff@3.1.3: {}
extract-zip@2.0.1:
dependencies:
debug: 4.4.3
get-stream: 5.2.0
yauzl: 2.10.0
optionalDependencies:
'@types/yauzl': 2.10.3
transitivePeerDependencies:
- supports-color
extsprintf@1.4.1:
optional: true
@ -2528,10 +2509,6 @@ snapshots:
fast-uri@3.1.0: {}
fd-slicer@1.1.0:
dependencies:
pend: 1.2.0
fdir@6.5.0(picomatch@4.0.3):
optionalDependencies:
picomatch: 4.0.3
@ -2569,6 +2546,13 @@ snapshots:
jsonfile: 6.2.0
universalify: 2.0.1
fs-extra@11.3.5:
dependencies:
graceful-fs: 4.2.11
jsonfile: 6.2.1
universalify: 2.0.1
optional: true
fs-extra@7.0.1:
dependencies:
graceful-fs: 4.2.11
@ -2804,6 +2788,13 @@ snapshots:
optionalDependencies:
graceful-fs: 4.2.11
jsonfile@6.2.1:
dependencies:
universalify: 2.0.1
optionalDependencies:
graceful-fs: 4.2.11
optional: true
keyv@4.5.4:
dependencies:
json-buffer: 3.0.1
@ -3015,8 +3006,6 @@ snapshots:
pe-library@0.4.1: {}
pend@1.2.0: {}
picocolors@1.1.1: {}
picomatch@4.0.3: {}
@ -3136,6 +3125,8 @@ snapshots:
semver@7.7.4: {}
semver@7.8.5: {}
serialize-error@7.0.1:
dependencies:
type-fest: 0.13.1
@ -3295,10 +3286,11 @@ snapshots:
uint8array-extras@1.5.0: {}
undici-types@7.16.0: {}
undici-types@7.18.2: {}
undici@7.28.0:
optional: true
unique-filename@4.0.0:
dependencies:
unique-slug: 5.0.0
@ -3384,9 +3376,4 @@ snapshots:
y18n: 5.0.8
yargs-parser: 21.1.1
yauzl@2.10.0:
dependencies:
buffer-crc32: 0.2.13
fd-slicer: 1.1.0
yocto-queue@0.1.0: {}

View file

@ -40,8 +40,12 @@ export const IPC_CHANNELS = {
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
// Auth token sync across windows
GET_AUTH_TOKENS: 'auth:get-tokens',
SET_AUTH_TOKENS: 'auth:set-tokens',
GET_ACCESS_TOKEN: 'auth:get-access-token',
REFRESH_ACCESS_TOKEN: 'auth:refresh-access-token',
LOGOUT: 'auth:logout',
AUTH_CHANGED: 'auth:changed',
AUTH_START_GOOGLE: 'auth:start-google',
AUTH_LOGIN_PASSWORD: 'auth:login-password',
// Keyboard shortcut configuration
GET_SHORTCUTS: 'shortcuts:get',
SET_SHORTCUTS: 'shortcuts:set',

View file

@ -1,4 +1,4 @@
import { app, ipcMain, shell } from 'electron';
import { app, BrowserWindow, ipcMain, shell } from 'electron';
import { IPC_CHANNELS } from './channels';
import {
getPermissionsStatus,
@ -52,8 +52,64 @@ import {
type AgentFilesystemTreeWatchOptions,
} from '../modules/agent-filesystem-tree-watcher';
import { installDownloadedUpdate } from '../modules/auto-updater';
import { secretStore } from '../modules/secret-store';
import { startGoogleOAuth } from '../modules/oauth';
let authTokens: { bearer: string; refresh: string } | null = null;
const REFRESH_TOKEN_KEY = 'surfsense_refresh_token';
let accessToken: string | null = null;
let refreshInFlight: Promise<string | null> | null = null;
type DesktopAuthResponse = {
access_token?: string;
refresh_token?: string | null;
};
function getBackendUrl(): string {
return (process.env.HOSTED_BACKEND_URL || process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || '').replace(
/\/+$/,
''
);
}
function broadcastAuthChanged(): void {
for (const win of BrowserWindow.getAllWindows()) {
win.webContents.send(IPC_CHANNELS.AUTH_CHANGED, { authed: !!accessToken, accessToken });
}
}
async function storeTokens(tokens: { bearer: string; refresh?: string | null }): Promise<void> {
accessToken = tokens.bearer || null;
if (tokens.refresh) {
await secretStore.set(REFRESH_TOKEN_KEY, tokens.refresh);
}
broadcastAuthChanged();
}
async function refreshAccessToken(): Promise<string | null> {
if (refreshInFlight) return refreshInFlight;
refreshInFlight = (async () => {
const refresh = await secretStore.get(REFRESH_TOKEN_KEY);
const backendUrl = getBackendUrl();
if (!refresh || !backendUrl) return null;
const response = await fetch(`${backendUrl}/auth/jwt/refresh`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refresh_token: refresh }),
});
if (!response.ok) return null;
const data = (await response.json()) as { access_token?: string; refresh_token?: string | null };
if (!data.access_token) return null;
await storeTokens({ bearer: data.access_token, refresh: data.refresh_token });
return data.access_token;
})().finally(() => {
refreshInFlight = null;
});
return refreshInFlight;
}
export function registerIpcHandlers(): void {
ipcMain.on(IPC_CHANNELS.OPEN_EXTERNAL, (_event, url: string) => {
@ -173,14 +229,81 @@ export function registerIpcHandlers(): void {
}
);
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
authTokens = tokens;
ipcMain.handle(IPC_CHANNELS.GET_ACCESS_TOKEN, async () => {
if (!accessToken) {
await refreshAccessToken();
}
return accessToken;
});
ipcMain.handle(IPC_CHANNELS.GET_AUTH_TOKENS, () => {
return authTokens;
ipcMain.handle(IPC_CHANNELS.REFRESH_ACCESS_TOKEN, () => {
return refreshAccessToken();
});
ipcMain.handle(IPC_CHANNELS.LOGOUT, async () => {
const backendUrl = getBackendUrl();
const refresh = await secretStore.get(REFRESH_TOKEN_KEY);
if (backendUrl && refresh) {
try {
await fetch(`${backendUrl}/auth/jwt/revoke`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refresh_token: refresh }),
});
} catch {
// Local logout is fail-closed even if the server revoke call fails.
}
}
accessToken = null;
await secretStore.clear(REFRESH_TOKEN_KEY);
broadcastAuthChanged();
});
ipcMain.handle(IPC_CHANNELS.AUTH_START_GOOGLE, async () => {
const backendUrl = getBackendUrl();
if (!backendUrl) {
throw new Error('Backend URL is not configured');
}
const tokens = await startGoogleOAuth(backendUrl);
await storeTokens({ bearer: tokens.access_token, refresh: tokens.refresh_token });
return { ok: true };
});
ipcMain.handle(
IPC_CHANNELS.AUTH_LOGIN_PASSWORD,
async (_event, payload: { email: string; password: string }) => {
const backendUrl = getBackendUrl();
if (!backendUrl) {
throw new Error('Backend URL is not configured');
}
const response = await fetch(`${backendUrl}/auth/desktop/login`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
});
if (!response.ok) {
let detail = 'Password login failed';
try {
const error = (await response.json()) as { detail?: string };
detail = error.detail || detail;
} catch {
// Keep the generic error if the backend did not return JSON.
}
throw new Error(detail);
}
const tokens = (await response.json()) as DesktopAuthResponse;
if (!tokens.access_token || !tokens.refresh_token) {
throw new Error('Password login did not return desktop tokens');
}
await storeTokens({ bearer: tokens.access_token, refresh: tokens.refresh_token });
return { ok: true };
}
);
ipcMain.handle(IPC_CHANNELS.GET_SHORTCUTS, () => getShortcuts());
ipcMain.handle(IPC_CHANNELS.GET_AUTO_LAUNCH, () => getAutoLaunchState());

View file

@ -17,6 +17,7 @@ import {
syncAutoLaunchOnStartup,
wasLaunchedAtLogin,
} from './modules/auto-launch';
import { purgeLegacyAuthCutover } from './modules/auth-cutover';
registerGlobalErrorHandlers();
app.setName('SurfSense');
@ -29,6 +30,7 @@ registerIpcHandlers();
app.whenReady().then(async () => {
initAnalytics();
await purgeLegacyAuthCutover();
const launchedAtLogin = wasLaunchedAtLogin();
const startedHidden = shouldStartHidden();
trackEvent('desktop_app_launched', {

View file

@ -0,0 +1,30 @@
import { app } from 'electron';
import { mkdir, readFile, writeFile } from 'node:fs/promises';
import path from 'node:path';
import { secretStore } from './secret-store';
const CUTOVER_FLAG_FILE = 'auth-cutover-v1.json';
const REFRESH_TOKEN_KEY = 'surfsense_refresh_token';
async function hasCompletedCutover(flagPath: string): Promise<boolean> {
try {
const raw = await readFile(flagPath, 'utf8');
return JSON.parse(raw)?.complete === true;
} catch {
return false;
}
}
export async function purgeLegacyAuthCutover(): Promise<void> {
const userDataPath = app.getPath('userData');
const flagPath = path.join(userDataPath, CUTOVER_FLAG_FILE);
if (await hasCompletedCutover(flagPath)) return;
await secretStore.clear(REFRESH_TOKEN_KEY);
await mkdir(userDataPath, { recursive: true });
await writeFile(
flagPath,
JSON.stringify({ complete: true, completedAt: new Date().toISOString() }),
{ mode: 0o600 }
);
}

View file

@ -22,8 +22,7 @@ function handleDeepLink(url: string) {
path: parsed.pathname,
});
if (parsed.hostname === 'auth' && parsed.pathname === '/callback') {
const params = parsed.searchParams.toString();
win.loadURL(`${getServerOrigin()}/auth/callback?${params}`);
win.loadURL(`${getServerOrigin()}/dashboard`);
}
win.show();

View file

@ -0,0 +1,72 @@
import http from 'node:http';
function escapeHtml(value: string): string {
return value
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');
}
function renderOAuthPage(title: string, message: string): string {
return `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>${escapeHtml(title)}</title>
<style>
:root {
color-scheme: dark;
}
* {
box-sizing: border-box;
}
body {
margin: 0;
min-height: 100vh;
display: grid;
place-items: center;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background: #303030;
background: oklch(0.24 0 0);
color: #fafafa;
}
main {
width: min(420px, calc(100vw - 32px));
text-align: center;
}
h1 {
margin: 0 0 12px;
font-size: 24px;
line-height: 1.2;
letter-spacing: -0.02em;
}
p {
margin: 0;
color: #d4d4d4;
line-height: 1.5;
}
</style>
</head>
<body>
<main>
<h1>${escapeHtml(title)}</h1>
<p>${escapeHtml(message)}</p>
</main>
</body>
</html>`;
}
export function writeOAuthPage(
res: http.ServerResponse,
statusCode: number,
title: string,
message: string,
_tone?: 'success' | 'error' | 'neutral',
): void {
res
.writeHead(statusCode, { 'content-type': 'text/html; charset=utf-8' })
.end(renderOAuthPage(title, message));
}

View file

@ -0,0 +1,155 @@
import { shell } from 'electron';
import crypto from 'node:crypto';
import http from 'node:http';
import { writeOAuthPage } from './oauth-page';
export interface DesktopAuthTokens {
access_token: string;
refresh_token: string;
}
const OAUTH_TIMEOUT_MS = 5 * 60 * 1000;
const OAUTH_CALLBACK_PATH = '/callback';
function base64Url(buffer: Buffer): string {
return buffer.toString('base64').replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
}
function randomUrlSafe(bytes = 32): string {
return base64Url(crypto.randomBytes(bytes));
}
function sha256(value: string): string {
return base64Url(crypto.createHash('sha256').update(value).digest());
}
function getGoogleDesktopClientId(): string {
const clientId = (process.env.GOOGLE_DESKTOP_CLIENT_ID || '').trim();
if (!clientId) {
throw new Error('Google desktop OAuth client ID is not configured');
}
return clientId;
}
export async function startGoogleOAuth(backendUrl: string): Promise<DesktopAuthTokens> {
const clientId = getGoogleDesktopClientId();
const state = randomUrlSafe();
const codeVerifier = randomUrlSafe(64);
const codeChallenge = sha256(codeVerifier);
return new Promise((resolve, reject) => {
let settled = false;
let port: number | null = null;
let timeout: NodeJS.Timeout | null = null;
const cleanup = () => {
if (timeout) {
clearTimeout(timeout);
timeout = null;
}
if (server.listening) {
server.close();
}
};
const fail = (error: Error) => {
if (settled) return;
settled = true;
cleanup();
reject(error);
};
const succeed = (tokens: DesktopAuthTokens) => {
if (settled) return;
settled = true;
cleanup();
resolve(tokens);
};
const server = http.createServer(async (req, res) => {
try {
const url = new URL(req.url || '/', 'http://127.0.0.1');
if (url.pathname !== OAUTH_CALLBACK_PATH) {
writeOAuthPage(res, 404, 'Not found', 'This OAuth callback endpoint is only used by SurfSense.');
return;
}
const oauthError = url.searchParams.get('error');
if (oauthError) {
const description = url.searchParams.get('error_description');
writeOAuthPage(res, 400, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
fail(new Error(description || `Google OAuth failed: ${oauthError}`));
return;
}
const code = url.searchParams.get('code');
const returnedState = url.searchParams.get('state');
if (!code || returnedState !== state) {
writeOAuthPage(res, 400, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
fail(new Error('Invalid OAuth callback'));
return;
}
if (!port) {
writeOAuthPage(res, 500, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
fail(new Error('OAuth loopback server was not ready'));
return;
}
const redirectUri = `http://127.0.0.1:${port}${OAUTH_CALLBACK_PATH}`;
const response = await fetch(`${backendUrl}/auth/desktop/session`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ code, code_verifier: codeVerifier, redirect_uri: redirectUri }),
});
if (!response.ok) {
let detail = 'Desktop session exchange failed';
try {
const error = (await response.json()) as { detail?: string };
detail = error.detail || detail;
} catch {
// Keep the generic exchange error if the backend did not return JSON.
}
writeOAuthPage(res, 401, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
fail(new Error(detail));
return;
}
const tokens = (await response.json()) as DesktopAuthTokens;
writeOAuthPage(res, 200, 'Authentication complete', 'You can close this window and return to SurfSense.', 'success');
succeed(tokens);
} catch (error) {
fail(error instanceof Error ? error : new Error('Google OAuth failed'));
}
});
server.listen(0, '127.0.0.1', () => {
const addressInfo = server.address();
if (!addressInfo || typeof addressInfo === 'string') {
fail(new Error('Unable to bind loopback OAuth server'));
return;
}
port = addressInfo.port;
timeout = setTimeout(() => {
fail(new Error('Google OAuth timed out'));
}, OAUTH_TIMEOUT_MS);
const redirectUri = `http://127.0.0.1:${port}${OAUTH_CALLBACK_PATH}`;
const authUrl = new URL('https://accounts.google.com/o/oauth2/v2/auth');
authUrl.searchParams.set('client_id', clientId);
authUrl.searchParams.set('redirect_uri', redirectUri);
authUrl.searchParams.set('response_type', 'code');
authUrl.searchParams.set('scope', 'openid email profile');
authUrl.searchParams.set('state', state);
authUrl.searchParams.set('code_challenge', codeChallenge);
authUrl.searchParams.set('code_challenge_method', 'S256');
shell.openExternal(authUrl.toString()).catch((error) => {
fail(error instanceof Error ? error : new Error('Unable to open browser for Google OAuth'));
});
});
server.on('error', (error) => {
fail(error);
});
});
}

View file

@ -0,0 +1,86 @@
import { app, safeStorage } from 'electron';
import fs from 'node:fs/promises';
import path from 'node:path';
export interface SecretStore {
set(key: string, value: string): Promise<void>;
get(key: string): Promise<string | null>;
clear(key: string): Promise<void>;
isHardwareBacked(): Promise<boolean>;
}
const memoryStore = new Map<string, string>();
const storePath = path.join(app.getPath('userData'), 'secrets.enc.json');
async function readDiskStore(): Promise<Record<string, string>> {
try {
const raw = await fs.readFile(storePath, 'utf8');
return JSON.parse(raw) as Record<string, string>;
} catch {
return {};
}
}
async function writeDiskStore(data: Record<string, string>): Promise<void> {
await fs.mkdir(path.dirname(storePath), { recursive: true });
await fs.writeFile(storePath, JSON.stringify(data), { encoding: 'utf8', mode: 0o600 });
}
async function canPersistEncryptedSecrets(): Promise<boolean> {
try {
if (safeStorage.getSelectedStorageBackend?.() === 'basic_text') {
return false;
}
return await safeStorage.isAsyncEncryptionAvailable();
} catch {
return false;
}
}
export const secretStore: SecretStore = {
async set(key, value) {
if (!(await canPersistEncryptedSecrets())) {
memoryStore.set(key, value);
return;
}
const encrypted = await safeStorage.encryptStringAsync(value);
const data = await readDiskStore();
data[key] = encrypted.toString('base64');
await writeDiskStore(data);
},
async get(key) {
if (!(await canPersistEncryptedSecrets())) {
return memoryStore.get(key) ?? null;
}
const data = await readDiskStore();
const encoded = data[key];
if (!encoded) return null;
try {
const decrypted = await safeStorage.decryptStringAsync(Buffer.from(encoded, 'base64'));
if (decrypted.shouldReEncrypt) {
await this.set(key, decrypted.result);
}
return decrypted.result;
} catch {
await this.clear(key);
return null;
}
},
async clear(key) {
memoryStore.delete(key);
const data = await readDiskStore();
if (key in data) {
delete data[key];
await writeDiskStore(data);
}
},
async isHardwareBacked() {
return canPersistEncryptedSecrets();
},
};

View file

@ -94,6 +94,10 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow {
session.defaultSession.webRequest.onBeforeRequest(rewriteFilter, (details, callback) => {
try {
const u = new URL(details.url);
if (!u.pathname.includes('/connectors/callback')) {
callback({});
return;
}
const originalHost = u.host;
const local = new URL(getServerOrigin());
u.protocol = local.protocol;

View file

@ -80,9 +80,18 @@ contextBridge.exposeInMainWorld('electronAPI', {
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content, searchSpaceId),
// Auth token sync across windows
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
setAuthTokens: (bearer: string, refresh: string) =>
ipcRenderer.invoke(IPC_CHANNELS.SET_AUTH_TOKENS, { bearer, refresh }),
getAccessToken: () => ipcRenderer.invoke(IPC_CHANNELS.GET_ACCESS_TOKEN),
refreshAccessToken: () => ipcRenderer.invoke(IPC_CHANNELS.REFRESH_ACCESS_TOKEN),
logout: () => ipcRenderer.invoke(IPC_CHANNELS.LOGOUT),
startGoogleOAuth: () => ipcRenderer.invoke(IPC_CHANNELS.AUTH_START_GOOGLE),
loginPassword: (email: string, password: string) =>
ipcRenderer.invoke(IPC_CHANNELS.AUTH_LOGIN_PASSWORD, { email, password }),
onAuthChanged: (callback: (payload: { authed: boolean; accessToken: string | null }) => void) => {
const listener = (_event: Electron.IpcRendererEvent, payload: { authed: boolean; accessToken: string | null }) =>
callback(payload);
ipcRenderer.on(IPC_CHANNELS.AUTH_CHANGED, listener);
return () => ipcRenderer.removeListener(IPC_CHANNELS.AUTH_CHANGED, listener);
},
// Keyboard shortcut configuration
getShortcuts: () => ipcRenderer.invoke(IPC_CHANNELS.GET_SHORTCUTS),

View file

@ -5,8 +5,8 @@ SurfSense supports ``AUTH_TYPE=LOCAL`` (email + password) and
There is no headless equivalent of the Google flow, so the harness handles
both modes by treating the JWT as the universal credential:
* **LOCAL**: harness POSTs form-encoded ``username`` + ``password`` to
``/auth/jwt/login``, reads ``{access_token, refresh_token}``.
* **LOCAL**: harness POSTs JSON ``email`` + ``password`` to
``/auth/desktop/login``, reads ``{access_token, refresh_token}``.
* **GOOGLE / pre-issued JWT**: operator pastes their existing JWT (and
optionally refresh token) into ``SURFSENSE_JWT`` /
``SURFSENSE_REFRESH_TOKEN``; harness skips login.
@ -22,7 +22,7 @@ MIRAGE runs.
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any
import httpx
@ -40,9 +40,8 @@ _NO_CREDENTIALS_MESSAGE = (
"No SurfSense credentials configured. Set ONE of:\n"
" (LOCAL) SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD\n"
" (GOOGLE) SURFSENSE_JWT (and optionally SURFSENSE_REFRESH_TOKEN)\n"
"For GOOGLE: log in to SurfSense in your browser, open DevTools → "
"Application → Local Storage → copy `surfsense_bearer_token` and "
"`surfsense_refresh_token` into those env vars."
"For GOOGLE: use a PAT or operator-issued bearer token and set "
"SURFSENSE_JWT (plus SURFSENSE_REFRESH_TOKEN if available)."
)
@ -69,7 +68,7 @@ async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None
1. ``SURFSENSE_JWT`` set use it directly. Refresh token captured if
supplied.
2. ``SURFSENSE_USER_EMAIL`` + ``SURFSENSE_USER_PASSWORD`` set
form-encoded POST to ``/auth/jwt/login``.
JSON POST to ``/auth/desktop/login``.
3. Neither raise ``CredentialError``.
The optional ``http`` argument lets tests inject a mocked client; if
@ -86,9 +85,9 @@ async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None
if config.has_local_mode():
async def _login(client: httpx.AsyncClient) -> TokenBundle:
response = await client.post(
f"{config.surfsense_api_base}/auth/jwt/login",
data={
"username": config.surfsense_user_email,
f"{config.surfsense_api_base}/auth/desktop/login",
json={
"email": config.surfsense_user_email,
"password": config.surfsense_user_password,
},
headers={"Accept": "application/json"},

View file

@ -46,8 +46,8 @@ async def test_acquire_token_jwt_mode_short_circuits():
@pytest.mark.asyncio
@respx.mock
async def test_acquire_token_local_mode_posts_form():
respx.post("http://test/auth/jwt/login").mock(
async def test_acquire_token_local_mode_posts_desktop_login_json():
respx.post("http://test/auth/desktop/login").mock(
return_value=httpx.Response(
200, json={"access_token": "T", "refresh_token": "R", "token_type": "bearer"}
)

View file

@ -41,6 +41,10 @@ NEXT_PUBLIC_POSTHOG_HOST=https://us.i.posthog.com
# "/zero" endpoint behind Caddy. Set it for local dev or packaged clients.
# ─────────────────────────────────────────────────────────────────────────────
# NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848
# Server-only shared secret that authorizes zero-cache when it calls
# /api/zero/query. Leave unset during the compatibility rollout, then set it
# once every zero-cache instance sends X-Api-Key.
# ZERO_QUERY_API_KEY=
# ─────────────────────────────────────────────────────────────────────────────
# Cloudflare Turnstile CAPTCHA for anonymous chat abuse prevention

View file

@ -11,6 +11,7 @@ import { useRuntimeConfig } from "@/components/providers/runtime-config";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors";
import { getPostLoginRedirectPath } from "@/lib/auth-utils";
import { ValidationError } from "@/lib/error";
import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events";
@ -38,7 +39,7 @@ export function LocalLoginForm() {
trackLoginAttempt("local");
try {
const data = await login({
await login({
username,
password,
grant_type: "password",
@ -47,14 +48,9 @@ export function LocalLoginForm() {
// Track successful login
trackLoginSuccess("local");
// Set flag so TokenHandler knows local login was already tracked
if (typeof window !== "undefined") {
sessionStorage.setItem("login_success_tracked", "true");
}
// Small delay to show success message
setTimeout(() => {
router.push(`/auth/callback?token=${data.access_token}`);
router.push(getPostLoginRedirectPath());
}, 500);
} catch (err) {
if (err instanceof ValidationError) {

View file

@ -30,8 +30,7 @@ function LoginContent() {
const logout = searchParams.get("logout");
const returnUrl = searchParams.get("returnUrl");
// Save returnUrl to localStorage so it persists through OAuth flows (e.g., Google)
// This is read by TokenHandler after successful authentication
// Save returnUrl for client-side login flows that can redirect directly after success.
if (returnUrl) {
setRedirectPath(decodeURIComponent(returnUrl));
}

View file

@ -12,8 +12,8 @@ import { Logo } from "@/components/Logo";
import { useRuntimeConfig } from "@/components/providers/runtime-config";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { useSession } from "@/hooks/use-session";
import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors";
import { getBearerToken } from "@/lib/auth-utils";
import { AppError, ValidationError } from "@/lib/error";
import {
trackRegistrationAttempt,
@ -37,18 +37,19 @@ export default function RegisterPage() {
message: null,
});
const router = useRouter();
const session = useSession();
const [{ mutateAsync: register, isPending: isRegistering }] = useAtom(registerMutationAtom);
// Check authentication type and redirect if not LOCAL
useEffect(() => {
if (getBearerToken()) {
if (session.status === "authenticated") {
router.replace("/dashboard");
return;
}
if (authType !== "LOCAL") {
router.push("/login");
}
}, [authType, router]);
}, [authType, router, session.status]);
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();

View file

@ -12,45 +12,66 @@ import { schema } from "@/zero/schema";
// (e.g. http://localhost:8929) does NOT resolve from inside the frontend
// container and would make every authenticated Zero query fail with a 503.
const backendURL = SERVER_BACKEND_URL.replace(/\/$/, "");
const zeroQueryApiKey = process.env.ZERO_QUERY_API_KEY;
function validateZeroCacheRequest(request: Request): NextResponse | null {
if (!zeroQueryApiKey) return null;
if (request.headers.get("X-Api-Key") === zeroQueryApiKey) return null;
return NextResponse.json({ error: "Forbidden" }, { status: 403 });
}
async function authenticateRequest(
request: Request
): Promise<{ ctx: Context; error?: never } | { ctx?: never; error: NextResponse }> {
): Promise<
{ ctx: Exclude<Context, undefined>; error?: never } | { ctx?: never; error: NextResponse }
> {
const authHeader = request.headers.get("Authorization");
if (!authHeader?.startsWith("Bearer ")) {
return { ctx: undefined };
const cookieHeader = request.headers.get("Cookie");
const headers: HeadersInit = {};
if (authHeader?.startsWith("Bearer ")) {
headers.Authorization = authHeader;
} else if (cookieHeader) {
headers.Cookie = cookieHeader;
} else {
return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) };
}
try {
const res = await fetch(`${backendURL}/users/me`, {
headers: { Authorization: authHeader },
const res = await fetch(`${backendURL}/zero/context`, {
headers,
});
if (!res.ok) {
return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) };
}
const user = await res.json();
return { ctx: { userId: String(user.id) } };
const ctx = (await res.json()) as Exclude<Context, undefined>;
return { ctx };
} catch {
return { error: NextResponse.json({ error: "Auth service unavailable" }, { status: 503 }) };
}
}
export async function POST(request: Request) {
const forbidden = validateZeroCacheRequest(request);
if (forbidden) {
return forbidden;
}
const auth = await authenticateRequest(request);
if (auth.error) {
return auth.error;
}
const result = await handleQueryRequest(
(name, args) => {
const result = await handleQueryRequest({
handler: (name, args) => {
const query = mustGetQuery(queries, name);
return query.fn({ args, ctx: auth.ctx });
},
schema,
request
);
request,
userID: auth.ctx.userId,
});
return NextResponse.json(result);
}

View file

@ -1,14 +0,0 @@
"use client";
import { Suspense } from "react";
import TokenHandler from "@/components/TokenHandler";
export default function AuthCallbackPage() {
// Suspense fallback returns null - the GlobalLoadingProvider handles the loading UI
// TokenHandler uses useGlobalLoadingEffect to show the loading screen
return (
<Suspense fallback={null}>
<TokenHandler redirectPath="/dashboard" tokenParamName="token" />
</Suspense>
);
}

View file

@ -69,7 +69,7 @@ import { useMessagesSync } from "@/hooks/use-messages-sync";
import { useThreadDetail, useThreadMessages } from "@/hooks/use-thread-queries";
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { getBearerToken } from "@/lib/auth-utils";
import { getDesktopAccessToken } from "@/lib/auth-fetch";
import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier";
import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
@ -917,29 +917,26 @@ export default function NewChatPage() {
// Cancel ongoing request
const cancelRun = useCallback(async () => {
if (threadId) {
const token = getBearerToken();
if (token) {
try {
const response = await fetch(
buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`),
{
method: "POST",
headers: {
Authorization: `Bearer ${token}`,
},
}
);
if (response.ok) {
const payload = (await response.json()) as {
error_code?: string;
};
if (payload.error_code === "TURN_CANCELLING") {
recentCancelRequestedAtRef.current = Date.now();
}
const token = await getDesktopAccessToken();
try {
const response = await fetch(
buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`),
{
method: "POST",
headers: token ? { Authorization: `Bearer ${token}` } : undefined,
credentials: "include",
}
);
if (response.ok) {
const payload = (await response.json()) as {
error_code?: string;
};
if (payload.error_code === "TURN_CANCELLING") {
recentCancelRequestedAtRef.current = Date.now();
}
} catch (error) {
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
}
} catch (error) {
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
}
}
if (abortControllerRef.current) {
@ -964,11 +961,7 @@ export default function NewChatPage() {
if (!userQuery.trim() && userImages.length === 0) return;
const token = getBearerToken();
if (!token) {
toast.error("Not authenticated. Please log in again.");
return;
}
const token = await getDesktopAccessToken();
// Lazy thread creation: create thread on first message if it doesn't exist
let currentThreadId = threadId;
@ -1149,8 +1142,9 @@ export default function NewChatPage() {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${token}`,
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
credentials: "include",
body: JSON.stringify({
chat_id: currentThreadId,
user_query: userQuery.trim(),
@ -1537,12 +1531,7 @@ export default function NewChatPage() {
stagedDecisionsByInterruptIdRef.current.clear();
setIsRunning(true);
const token = getBearerToken();
if (!token) {
toast.error("Not authenticated. Please log in again.");
setIsRunning(false);
return;
}
const token = await getDesktopAccessToken();
const controller = new AbortController();
abortControllerRef.current = controller;
@ -1648,8 +1637,9 @@ export default function NewChatPage() {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${token}`,
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
credentials: "include",
body: JSON.stringify({
search_space_id: searchSpaceId,
decisions,
@ -1981,11 +1971,7 @@ export default function NewChatPage() {
abortControllerRef.current = null;
}
const token = getBearerToken();
if (!token) {
toast.error("Not authenticated. Please log in again.");
return;
}
const token = await getDesktopAccessToken();
// Extract the original user query BEFORE removing messages (for reload mode)
let userQueryToDisplay: string | undefined;
@ -2104,8 +2090,9 @@ export default function NewChatPage() {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${token}`,
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
credentials: "include",
body: JSON.stringify(requestBody),
signal: controller.signal,
})

View file

@ -13,13 +13,15 @@ import { Logo } from "@/components/Logo";
import { ModelProviderConnectionsPanel } from "@/components/settings/model-connections/model-provider-connections-panel";
import { Button } from "@/components/ui/button";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
import { useSession } from "@/hooks/use-session";
import { redirectToLogin } from "@/lib/auth-utils";
import { hasEnabledChatModel, isLlmOnboardingComplete } from "@/lib/onboarding";
export default function OnboardPage() {
const router = useRouter();
const params = useParams();
const searchSpaceId = Number(params.search_space_id);
const session = useSession();
const { data: globalConnections = [], isLoading: globalLoading } = useAtomValue(
globalModelConnectionsAtom
);
@ -29,8 +31,8 @@ export default function OnboardPage() {
useAtomValue(globalLlmConfigStatusAtom);
useEffect(() => {
if (!getBearerToken()) redirectToLogin();
}, []);
if (session.status === "unauthenticated") redirectToLogin();
}, [session.status]);
const hasUsableChatModel = useMemo(
() => hasEnabledChatModel([...globalConnections, ...connections]),
@ -43,7 +45,8 @@ export default function OnboardPage() {
connections
);
const isLoading = globalLoading || rolesLoading || globalConfigStatusLoading;
const isLoading =
session.status === "loading" || globalLoading || rolesLoading || globalConfigStatusLoading;
// Onboarding only applies when no global_llm_config.yaml exists. If a global
// config is present (or onboarding is already complete), leave this page.

View file

@ -1,10 +1,20 @@
"use client";
import { Check, Copy, Info, Plus, Trash2 } from "lucide-react";
import { Check, Copy, Info, Trash2 } from "lucide-react";
import { useCallback, useMemo, useState } from "react";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Badge } from "@/components/ui/badge";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import {
Dialog,
DialogContent,
@ -16,6 +26,7 @@ import {
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Skeleton } from "@/components/ui/skeleton";
import { Spinner } from "@/components/ui/spinner";
import { usePats } from "@/hooks/use-pats";
import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils";
@ -26,6 +37,7 @@ export function ApiKeyContent() {
const [label, setLabel] = useState("");
const [expiresInDays, setExpiresInDays] = useState("");
const [copiedToken, setCopiedToken] = useState(false);
const [deleteTarget, setDeleteTarget] = useState<{ id: number; label: string } | null>(null);
const sortedTokens = useMemo(() => tokens, [tokens]);
@ -51,95 +63,112 @@ export function ApiKeyContent() {
}
}, [createdToken]);
const handleDelete = useCallback(
async (id: number, tokenLabel: string) => {
if (!window.confirm(`Delete personal access token "${tokenLabel}"? This cannot be undone.`)) {
return;
}
await deleteToken(id);
},
[deleteToken]
);
const handleConfirmDelete = useCallback(async () => {
if (!deleteTarget) return;
await deleteToken(deleteTarget.id);
setDeleteTarget(null);
}, [deleteTarget, deleteToken]);
return (
<div className="space-y-6 min-w-0 overflow-hidden">
<div className="space-y-6 min-w-0">
<Alert>
<Info />
<AlertDescription>
Personal access tokens are long-lived credentials for extensions, Obsidian, and
programmatic API clients. Copy a token when you create it; it is shown only once.
API keys let extensions, Obsidian, and other apps connect to SurfSense.
</AlertDescription>
</Alert>
<div className="flex items-center justify-between gap-3">
<div>
<h3 className="text-sm font-semibold tracking-tight">Personal access tokens</h3>
<h3 className="text-sm font-semibold tracking-tight">API keys</h3>
<p className="text-xs text-muted-foreground">
Expired tokens stay listed until you delete them.
Expired API keys stay listed until you delete them.
</p>
</div>
<Button size="sm" onClick={() => setCreateOpen(true)}>
<Plus className="mr-2 h-4 w-4" />
Create token
Create API key
</Button>
</div>
<div className="min-w-0 overflow-hidden rounded-lg border border-border/60">
{isLoading ? (
<div className="space-y-3 p-4">
<Skeleton className="h-12 w-full" />
<Skeleton className="h-12 w-full" />
</div>
) : sortedTokens.length > 0 ? (
<div className="divide-y divide-border/60">
{sortedTokens.map((token) => {
const expiresAt = token.expires_at ? new Date(token.expires_at) : null;
const isExpired = expiresAt ? expiresAt.getTime() <= Date.now() : false;
return (
<div key={token.id} className="flex items-center gap-3 p-4">
{isLoading ? (
<div className="-m-1 grid grid-cols-1 gap-3 p-1">
{["skeleton-a", "skeleton-b"].map((key) => (
<Card
key={key}
className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"
>
<CardContent className="p-4 flex flex-col gap-3 h-full min-h-24">
<Skeleton className="h-4 w-32 md:w-40 bg-accent" />
<Skeleton className="h-3 w-full bg-accent" />
<Skeleton className="h-3 w-24 md:w-28 bg-accent" />
</CardContent>
</Card>
))}
</div>
) : sortedTokens.length > 0 ? (
<div className="-m-1 grid grid-cols-1 gap-3 p-1">
{sortedTokens.map((token) => {
const expiresAt = token.expires_at ? new Date(token.expires_at) : null;
const isExpired = expiresAt ? expiresAt.getTime() <= Date.now() : false;
return (
<Card
key={token.id}
className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"
>
<CardContent className="flex min-h-24 items-center gap-3 p-4">
<div className="min-w-0 flex-1">
<div className="flex items-center gap-2">
<p className="truncate text-sm font-medium">{token.label}</p>
{isExpired ? <Badge variant="secondary">Expired</Badge> : null}
<div className="flex flex-col gap-1">
<div className="flex items-center gap-2">
<h4 className="truncate text-sm font-semibold tracking-tight">
{token.label}
</h4>
{isExpired ? (
<span className="rounded-md border-0 bg-muted px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground">
Expired
</span>
) : null}
</div>
<p className="truncate font-mono text-xs text-muted-foreground">
{token.prefix}...
</p>
<p className="text-xs text-muted-foreground">
Expires: {expiresAt ? expiresAt.toLocaleDateString() : "Never"} · Last used:{" "}
{token.last_used_at ? new Date(token.last_used_at).toLocaleString() : "Never"}
</p>
</div>
<p className="font-mono text-xs text-muted-foreground">{token.prefix}...</p>
<p className="text-xs text-muted-foreground">
Expires: {expiresAt ? expiresAt.toLocaleDateString() : "Never"} · Last used:{" "}
{token.last_used_at
? new Date(token.last_used_at).toLocaleString()
: "Never"}
</p>
</div>
<Button
variant="ghost"
size="icon"
disabled={isMutating}
onClick={() => handleDelete(token.id, token.label)}
onClick={() => setDeleteTarget({ id: token.id, label: token.label })}
className="h-7 w-7 shrink-0 rounded-lg text-muted-foreground transition-opacity duration-150 hover:text-accent-foreground sm:opacity-0 sm:pointer-events-none sm:group-hover:opacity-100 sm:group-hover:pointer-events-auto"
>
<Trash2 className="h-4 w-4 text-muted-foreground" />
<Trash2 className="h-4 w-4" />
</Button>
</div>
);
})}
</div>
) : (
<p className="p-6 text-center text-sm text-muted-foreground">
No personal access tokens yet.
</p>
)}
</div>
</CardContent>
</Card>
);
})}
</div>
) : (
<p className="py-6 text-center text-sm text-muted-foreground">
No API keys yet.
</p>
)}
<Dialog open={createOpen} onOpenChange={setCreateOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create personal access token</DialogTitle>
<DialogTitle>Create API key</DialogTitle>
<DialogDescription>
Name this token so you can recognize where it is used later.
Name this API key so you can recognize where it is used later.
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label htmlFor="pat-label">Label</Label>
<Label htmlFor="pat-label">Name</Label>
<Input
id="pat-label"
value={label}
@ -160,11 +189,24 @@ export function ApiKeyContent() {
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setCreateOpen(false)}>
<Button
type="button"
variant="secondary"
size="sm"
onClick={() => setCreateOpen(false)}
disabled={isMutating}
className="text-sm h-9"
>
Cancel
</Button>
<Button disabled={isMutating || !label.trim()} onClick={handleCreate}>
Create token
<Button
size="sm"
disabled={isMutating || !label.trim()}
onClick={handleCreate}
className="relative text-sm h-9 min-w-[128px]"
>
<span className={isMutating ? "opacity-0" : ""}>Create API key</span>
{isMutating && <Spinner size="sm" className="absolute" />}
</Button>
</DialogFooter>
</DialogContent>
@ -173,17 +215,21 @@ export function ApiKeyContent() {
<Dialog open={!!createdToken} onOpenChange={(open) => !open && setCreatedToken(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle>Copy your token now</DialogTitle>
<DialogTitle>Copy your API key now</DialogTitle>
<DialogDescription>
This token is shown only once. Store it somewhere secure before closing this
dialog.
This API key is shown only once. Store it somewhere secure before closing this dialog.
</DialogDescription>
</DialogHeader>
<div className="flex items-center gap-2 rounded-md border border-border/60 bg-muted/30 p-2">
<code className="min-w-0 flex-1 overflow-x-auto whitespace-nowrap text-xs">
{createdToken?.token}
</code>
<Button variant="outline" size="sm" onClick={copyCreatedToken}>
<Button
variant="outline"
size="sm"
onClick={copyCreatedToken}
className="border-0 bg-muted/30 hover:bg-muted/50"
>
{copiedToken ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
</Button>
</div>
@ -192,6 +238,41 @@ export function ApiKeyContent() {
</DialogFooter>
</DialogContent>
</Dialog>
<AlertDialog
open={deleteTarget !== null}
onOpenChange={(open) => !open && setDeleteTarget(null)}
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Delete API key?</AlertDialogTitle>
<AlertDialogDescription>
<span className="font-medium text-foreground">{deleteTarget?.label}</span> will be
permanently removed. This cannot be undone.
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isMutating}>Cancel</AlertDialogCancel>
<AlertDialogAction
disabled={isMutating}
className="bg-destructive text-white hover:bg-destructive/90"
onClick={(event) => {
event.preventDefault();
void handleConfirmDelete();
}}
>
{isMutating ? (
<span className="inline-flex items-center gap-2">
<Spinner size="xs" />
Deleting...
</span>
) : (
"Delete"
)}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</div>
);
}

View file

@ -38,13 +38,13 @@ export function CommunityPromptsContent() {
const list = prompts ?? [];
return (
<div className="space-y-6 min-w-0 overflow-hidden">
<div className="space-y-6 min-w-0">
<p className="text-sm text-muted-foreground">
Prompts shared by other users. Add any to your collection with one click.
</p>
{isLoading && (
<div className="space-y-2">
<div className="-m-1 space-y-2 p-1">
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
<Card key={key} className="border-accent bg-accent/20">
<CardContent className="p-4 flex flex-col gap-3 min-h-24">
@ -76,7 +76,7 @@ export function CommunityPromptsContent() {
)}
{!isLoading && !isError && list.length > 0 && (
<div className="space-y-2">
<div className="-m-1 space-y-2 p-1">
{list.map((prompt) => (
<Card
key={prompt.id}

View file

@ -19,7 +19,7 @@ import { Separator } from "@/components/ui/separator";
import { Skeleton } from "@/components/ui/skeleton";
import type { SearchSpace } from "@/contracts/types/search-space.types";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { authenticatedFetch } from "@/lib/auth-fetch";
import { buildBackendUrl } from "@/lib/env-config";
import { cn } from "@/lib/utils";

View file

@ -148,7 +148,7 @@ export function PromptsContent() {
const list = prompts ?? [];
return (
<div className="space-y-6 min-w-0 overflow-hidden">
<div className="space-y-6 min-w-0">
<div className="flex items-center justify-between">
<p className="text-sm text-muted-foreground">
Create prompt templates triggered with <ShortcutKbd keys={["/"]} className="ml-0" /> in
@ -276,7 +276,7 @@ export function PromptsContent() {
</Dialog>
{isLoading && (
<div className="space-y-2">
<div className="-m-1 space-y-2 p-1">
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
<Card key={key} className="border-accent bg-accent/20">
<CardContent className="p-4 flex flex-col gap-3 min-h-24">
@ -308,7 +308,7 @@ export function PromptsContent() {
)}
{!isLoading && !isError && list.length > 0 && (
<div className="space-y-2">
<div className="-m-1 space-y-2 p-1">
{list.map((prompt) => (
<div
key={prompt.id}

View file

@ -3,31 +3,29 @@
import { useEffect, useState } from "react";
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils";
import { useSession } from "@/hooks/use-session";
import { redirectToLogin } from "@/lib/auth-utils";
import { queryClient } from "@/lib/query-client/client";
export function DashboardShell({ children }: { children: React.ReactNode }) {
const [isCheckingAuth, setIsCheckingAuth] = useState(true);
const session = useSession();
// Use the global loading screen - spinner animation won't reset
useGlobalLoadingEffect(isCheckingAuth);
useEffect(() => {
async function checkAuth() {
let token = getBearerToken();
if (!token) {
const synced = await ensureTokensFromElectron();
if (synced) token = getBearerToken();
}
if (!token) {
if (session.status === "loading") return;
if (session.status === "unauthenticated") {
redirectToLogin();
return;
}
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
setIsCheckingAuth(false);
}
checkAuth();
}, []);
void checkAuth();
}, [session.status]);
// Return null while loading - the global provider handles the loading UI
if (isCheckingAuth) {

View file

@ -9,13 +9,7 @@ import { useEffect, useState } from "react";
import { searchSpacesAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { CreateSearchSpaceDialog } from "@/components/layout";
import { Button } from "@/components/ui/button";
import {
Card,
CardDescription,
CardFooter,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Card, CardDescription, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
function ErrorScreen({ message }: { message: string }) {

View file

@ -1,12 +1,10 @@
"use client";
import { useAtom } from "jotai";
import { Crop, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react";
import Image from "next/image";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder";
import { useIsGoogleAuth } from "@/components/providers/runtime-config";
import { Button } from "@/components/ui/button";
@ -17,8 +15,7 @@ import { ShortcutKbd } from "@/components/ui/shortcut-kbd";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { setBearerToken } from "@/lib/auth-utils";
import { buildBackendUrl } from "@/lib/env-config";
import { getPostLoginRedirectPath } from "@/lib/auth-utils";
type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist";
type ShortcutMap = typeof DEFAULT_SHORTCUTS;
@ -190,12 +187,12 @@ export default function DesktopLoginPage() {
const router = useRouter();
const api = useElectronAPI();
const isGoogleAuth = useIsGoogleAuth();
const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom);
const [email, setEmail] = useState("");
const [password, setPassword] = useState("");
const [showPassword, setShowPassword] = useState(false);
const [loginError, setLoginError] = useState<string | null>(null);
const [isLoggingIn, setIsLoggingIn] = useState(false);
const [isGoogleRedirecting, setIsGoogleRedirecting] = useState(false);
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
@ -237,10 +234,17 @@ export default function DesktopLoginPage() {
[updateShortcut]
);
const handleGoogleLogin = () => {
const handleGoogleLogin = async () => {
if (isGoogleRedirecting) return;
setIsGoogleRedirecting(true);
window.location.href = buildBackendUrl("/auth/google/authorize-redirect");
try {
await api?.startGoogleOAuth?.();
await autoSetSearchSpace();
router.push(getPostLoginRedirectPath());
} catch (error) {
setIsGoogleRedirecting(false);
toast.error(error instanceof Error ? error.message : "Google sign-in failed");
}
};
const autoSetSearchSpace = async () => {
@ -259,23 +263,19 @@ export default function DesktopLoginPage() {
const handleLocalLogin = async (e: React.FormEvent) => {
e.preventDefault();
setLoginError(null);
if (isLoggingIn) return;
setIsLoggingIn(true);
try {
const data = await login({
username: email,
password,
grant_type: "password",
});
if (typeof window !== "undefined") {
sessionStorage.setItem("login_success_tracked", "true");
if (!api?.loginPassword) {
throw new Error("Desktop password login is not available");
}
await api.loginPassword(email, password);
setBearerToken(data.access_token);
await autoSetSearchSpace();
setTimeout(() => {
router.push(`/auth/callback?token=${data.access_token}`);
router.push(getPostLoginRedirectPath());
}, 300);
} catch (err) {
if (err instanceof Error) {
@ -283,6 +283,8 @@ export default function DesktopLoginPage() {
} else {
setLoginError("Login failed. Please check your credentials.");
}
} finally {
setIsLoggingIn(false);
}
};

View file

@ -30,8 +30,9 @@ import {
} from "@/components/ui/card";
import { Spinner } from "@/components/ui/spinner";
import type { AcceptInviteResponse } from "@/contracts/types/invites.types";
import { useSession } from "@/hooks/use-session";
import { invitesApiService } from "@/lib/apis/invites-api.service";
import { getBearerToken, setRedirectPath } from "@/lib/auth-utils";
import { setRedirectPath } from "@/lib/auth-utils";
import {
trackSearchSpaceInviteAccepted,
trackSearchSpaceInviteDeclined,
@ -43,6 +44,7 @@ export default function InviteAcceptPage() {
const params = useParams();
const router = useRouter();
const inviteCode = params.invite_code as string;
const session = useSession();
const { data: inviteInfo = null, isLoading: loading } = useQuery({
queryKey: cacheKeys.invites.info(inviteCode),
@ -81,11 +83,9 @@ export default function InviteAcceptPage() {
// Check if user is logged in
useEffect(() => {
if (typeof window !== "undefined") {
const token = getBearerToken();
setIsLoggedIn(!!token);
}
}, []);
if (session.status === "loading") return;
setIsLoggedIn(session.status === "authenticated");
}, [session.status]);
const handleAccept = async () => {
setAccepting(true);

View file

@ -5,6 +5,7 @@ import { Roboto } from "next/font/google";
import Script from "next/script";
import { AnnouncementToastProvider } from "@/components/announcements/AnnouncementToastProvider";
import { DesktopUpdateToast } from "@/components/desktop/desktop-update-toast";
import { AuthCutoverPurge } from "@/components/providers/AuthCutoverPurge";
import { GlobalLoadingProvider } from "@/components/providers/GlobalLoadingProvider";
import { I18nProvider } from "@/components/providers/I18nProvider";
import { PostHogProvider } from "@/components/providers/PostHogProvider";
@ -17,13 +18,10 @@ import {
import { ThemeProvider } from "@/components/theme/theme-provider";
import { Toaster } from "@/components/ui/sonner";
import { LocaleProvider } from "@/contexts/LocaleContext";
import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config";
import { PlatformProvider } from "@/contexts/platform-context";
import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config";
import { ReactQueryClientProvider } from "@/lib/query-client/query-client.provider";
import {
getRuntimeAuthInitScript,
resolveRuntimeAuthUiMode,
} from "@/lib/runtime-auth-config";
import { getRuntimeAuthInitScript, resolveRuntimeAuthUiMode } from "@/lib/runtime-auth-config";
import { cn } from "@/lib/utils";
const roboto = Roboto({
@ -164,6 +162,7 @@ export default function RootLayout({
<PlatformProvider>
<RootProvider>
<ReactQueryClientProvider>
<AuthCutoverPurge />
<ZeroProvider>
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
</ZeroProvider>

View file

@ -15,6 +15,7 @@ export async function GET(request: NextRequest) {
headers: {
Authorization: request.headers.get("authorization") || "",
"X-API-Key": request.headers.get("x-api-key") || "",
Cookie: request.headers.get("cookie") || "",
},
cache: "no-store",
});

View file

@ -1,6 +1,6 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { agentFlagsApiService } from "@/lib/apis/agent-flags-api.service";
import { getBearerToken } from "@/lib/auth-utils";
import { isAuthenticated } from "@/lib/auth-utils";
export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const;
@ -12,6 +12,6 @@ export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const;
export const agentFlagsAtom = atomWithQuery(() => ({
queryKey: AGENT_FLAGS_QUERY_KEY,
staleTime: 10 * 60 * 1000,
enabled: !!getBearerToken(),
enabled: isAuthenticated(),
queryFn: () => agentFlagsApiService.get(),
}));

Some files were not shown because too many files have changed in this diff Show more