mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
Merge pull request #1535 from AnishSarkar22/feat/auth-revamp
feat(auth): complete session auth cutover with desktop oauth support
This commit is contained in:
commit
6950646bf1
158 changed files with 4032 additions and 1270 deletions
2
.github/workflows/desktop-release.yml
vendored
2
.github/workflows/desktop-release.yml
vendored
|
|
@ -113,6 +113,7 @@ jobs:
|
|||
env:
|
||||
HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }}
|
||||
GOOGLE_DESKTOP_CLIENT_ID: ${{ vars.GOOGLE_DESKTOP_CLIENT_ID }}
|
||||
POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }}
|
||||
POSTHOG_HOST: ${{ vars.POSTHOG_HOST }}
|
||||
|
||||
|
|
@ -143,6 +144,7 @@ jobs:
|
|||
working-directory: surfsense_desktop
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GOOGLE_DESKTOP_CLIENT_ID: ${{ vars.GOOGLE_DESKTOP_CLIENT_ID }}
|
||||
WINDOWS_PUBLISHER_NAME: ${{ vars.WINDOWS_PUBLISHER_NAME }}
|
||||
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
||||
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -20,3 +20,4 @@ surfsense_web/blob-report/
|
|||
content_research/
|
||||
automation-design-plan.md
|
||||
automation-frontend-builder-plan.md
|
||||
surfsense_desktop/.env
|
||||
|
|
|
|||
|
|
@ -30,6 +30,11 @@ SECRET_KEY=replace_me_with_a_random_string
|
|||
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
|
||||
AUTH_TYPE=LOCAL
|
||||
|
||||
# Cloud only: set COOKIE_DOMAIN=.surfsense.com so api., zero., and app
|
||||
# subdomains all receive the same first-party session cookie. Leave empty for
|
||||
# self-hosted Docker where Caddy serves a single origin.
|
||||
# COOKIE_DOMAIN=
|
||||
|
||||
# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them.
|
||||
DEPLOYMENT_MODE=self-hosted
|
||||
|
||||
|
|
@ -135,6 +140,19 @@ CERT_EMAIL=
|
|||
# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
|
||||
#
|
||||
# Forward browser session cookies from zero-cache to the query route. Keep this
|
||||
# enabled before switching the web app to cookie-only auth.
|
||||
# ZERO_QUERY_FORWARD_COOKIES=true
|
||||
#
|
||||
# Optional shared secret for the zero-cache -> /api/zero/query hop. Set the same
|
||||
# value on zero-cache and the frontend. When unset, the query route accepts the
|
||||
# request for backward-compatible rollout.
|
||||
# ZERO_QUERY_API_KEY=
|
||||
#
|
||||
# Bounds for auth revocation and RBAC membership changes on already-open sockets.
|
||||
# ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=60
|
||||
# ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=60
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Database (defaults work out of the box, change for security)
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ services:
|
|||
# container to run migrations, so you must run `uv run alembic upgrade head`
|
||||
# from `surfsense_backend/` on the host BEFORE `docker compose up -d`.
|
||||
zero-cache:
|
||||
image: rocicorp/zero:1.4.0
|
||||
image: rocicorp/zero:1.6.0
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-4848}:4848"
|
||||
extra_hosts:
|
||||
|
|
@ -120,6 +120,10 @@ services:
|
|||
- ZERO_CVR_MAX_CONNS=${ZERO_CVR_MAX_CONNS:-30}
|
||||
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://host.docker.internal:3000/api/zero/query}
|
||||
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://host.docker.internal:3000/api/zero/mutate}
|
||||
- ZERO_QUERY_FORWARD_COOKIES=${ZERO_QUERY_FORWARD_COOKIES:-true}
|
||||
- ZERO_QUERY_API_KEY=${ZERO_QUERY_API_KEY:-}
|
||||
- ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
|
||||
- ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
|
||||
volumes:
|
||||
- zero_cache_data:/data
|
||||
restart: unless-stopped
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ services:
|
|||
condition: service_started
|
||||
|
||||
zero-cache:
|
||||
image: rocicorp/zero:1.4.0
|
||||
image: rocicorp/zero:1.6.0
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-4848}:4848"
|
||||
extra_hosts:
|
||||
|
|
@ -243,6 +243,10 @@ services:
|
|||
- ZERO_CVR_MAX_CONNS=${ZERO_CVR_MAX_CONNS:-30}
|
||||
- ZERO_QUERY_URL=${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
|
||||
- ZERO_MUTATE_URL=${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
|
||||
- ZERO_QUERY_FORWARD_COOKIES=${ZERO_QUERY_FORWARD_COOKIES:-true}
|
||||
- ZERO_QUERY_API_KEY=${ZERO_QUERY_API_KEY:-}
|
||||
- ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS=${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
|
||||
- ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS=${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
|
||||
volumes:
|
||||
- zero_cache_data:/data
|
||||
restart: unless-stopped
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ services:
|
|||
restart: unless-stopped
|
||||
|
||||
zero-cache:
|
||||
image: rocicorp/zero:1.4.0
|
||||
image: rocicorp/zero:1.6.0
|
||||
expose:
|
||||
- "4848"
|
||||
extra_hosts:
|
||||
|
|
@ -268,6 +268,10 @@ services:
|
|||
ZERO_CVR_MAX_CONNS: ${ZERO_CVR_MAX_CONNS:-30}
|
||||
ZERO_QUERY_URL: ${ZERO_QUERY_URL:-http://frontend:3000/api/zero/query}
|
||||
ZERO_MUTATE_URL: ${ZERO_MUTATE_URL:-http://frontend:3000/api/zero/mutate}
|
||||
ZERO_QUERY_FORWARD_COOKIES: ${ZERO_QUERY_FORWARD_COOKIES:-true}
|
||||
ZERO_QUERY_API_KEY: ${ZERO_QUERY_API_KEY:-}
|
||||
ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS: ${ZERO_AUTH_REVALIDATE_INTERVAL_SECONDS:-60}
|
||||
ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS: ${ZERO_AUTH_RETRANSFORM_INTERVAL_SECONDS:-60}
|
||||
volumes:
|
||||
- zero_cache_data:/data
|
||||
restart: unless-stopped
|
||||
|
|
|
|||
|
|
@ -81,9 +81,24 @@ STRIPE_RECONCILIATION_INTERVAL=10m
|
|||
|
||||
SECRET_KEY=SECRET
|
||||
|
||||
# JWT Token Lifetimes (optional, defaults shown)
|
||||
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
|
||||
# JWT/session lifetimes (optional, defaults shown)
|
||||
# ACCESS_TOKEN_LIFETIME_SECONDS=1800 # 30 minutes
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 14-day inactivity window
|
||||
# REFRESH_ROTATION_GRACE_SECONDS=45
|
||||
# REFRESH_ABSOLUTE_LIFETIME_SECONDS=2592000 # 30-day absolute cap
|
||||
#
|
||||
# Web session cookies. Leave COOKIE_DOMAIN empty for self-hosted same-origin
|
||||
# Docker. In cloud, use .surfsense.com so api., zero., and the app share the
|
||||
# first-party session cookie.
|
||||
# SESSION_COOKIE_NAME=surfsense_session
|
||||
# REFRESH_COOKIE_NAME=surfsense_refresh
|
||||
# SESSION_COOKIE_SECURE_POLICY=auto
|
||||
# SESSION_COOKIE_SAMESITE=lax
|
||||
# COOKIE_DOMAIN=
|
||||
#
|
||||
# Comma-separated allow-list for cookie-session unsafe requests. Defaults also
|
||||
# include NEXT_FRONTEND_URL and SURFSENSE_PUBLIC_URL when set.
|
||||
# CSRF_ALLOWED_ORIGINS=http://localhost:3000
|
||||
# Personal Access Tokens (PATs). Empty/unset = no maximum; users may create
|
||||
# never-expiring PATs. When set, PAT creation requires an expiry <= this many days.
|
||||
# PAT_MAX_EXPIRY_DAYS=
|
||||
|
|
@ -115,6 +130,8 @@ REGISTRATION_ENABLED=TRUE or FALSE
|
|||
# For Google Auth Only
|
||||
GOOGLE_OAUTH_CLIENT_ID=924507538m
|
||||
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
|
||||
GOOGLE_DESKTOP_CLIENT_ID=your_google_desktop_client_id
|
||||
GOOGLE_DESKTOP_CLIENT_SECRET=your_google_desktop_client_secret
|
||||
GOOGLE_PICKER_API_KEY=your-google-picker-api-key
|
||||
|
||||
# Google Connector Specific Configurations
|
||||
|
|
|
|||
|
|
@ -77,7 +77,5 @@ def upgrade() -> None:
|
|||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled"
|
||||
)
|
||||
op.execute("ALTER TABLE searchspaces DROP COLUMN IF EXISTS api_access_enabled")
|
||||
op.execute("DROP TABLE IF EXISTS personal_access_tokens")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
"""publish Zero authz parent tables
|
||||
|
||||
Revision ID: 167
|
||||
Revises: 166
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
from app.zero_publication import apply_publication
|
||||
|
||||
revision: str = "167"
|
||||
down_revision: str | None = "166"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
apply_publication(op.get_bind())
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op. Historical publication shapes are immutable."""
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""harden refresh token schema
|
||||
|
||||
Revision ID: 168
|
||||
Revises: 167
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "168"
|
||||
down_revision: str | None = "167"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = NOW()
|
||||
WHERE is_revoked = TRUE
|
||||
"""
|
||||
)
|
||||
op.alter_column(
|
||||
"refresh_tokens",
|
||||
"token_hash",
|
||||
existing_type=sa.String(length=256),
|
||||
type_=sa.String(length=64),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.drop_column("refresh_tokens", "is_revoked")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"refresh_tokens",
|
||||
sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = TRUE
|
||||
WHERE revoked_at IS NOT NULL
|
||||
"""
|
||||
)
|
||||
op.alter_column("refresh_tokens", "is_revoked", server_default=None)
|
||||
op.alter_column(
|
||||
"refresh_tokens",
|
||||
"token_hash",
|
||||
existing_type=sa.String(length=64),
|
||||
type_=sa.String(length=256),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.drop_column("refresh_tokens", "absolute_expiry")
|
||||
op.drop_column("refresh_tokens", "revoked_at")
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""migrate Google OAuth account IDs to sub
|
||||
|
||||
Revision ID: 169
|
||||
Revises: 168
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "169"
|
||||
down_revision: str | None = "168"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _oauth_account_table_exists() -> bool:
|
||||
bind = op.get_bind()
|
||||
return bool(
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'oauth_account'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).scalar()
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _oauth_account_table_exists():
|
||||
return
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE oauth_account AS legacy
|
||||
SET account_id = regexp_replace(legacy.account_id, '^people/', '')
|
||||
WHERE legacy.oauth_name = 'google'
|
||||
AND legacy.account_id LIKE 'people/%'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM oauth_account AS canonical
|
||||
WHERE canonical.oauth_name = 'google'
|
||||
AND canonical.account_id = regexp_replace(legacy.account_id, '^people/', '')
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if not _oauth_account_table_exists():
|
||||
return
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE oauth_account AS canonical
|
||||
SET account_id = 'people/' || canonical.account_id
|
||||
WHERE canonical.oauth_name = 'google'
|
||||
AND canonical.account_id NOT LIKE 'people/%'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM oauth_account AS legacy
|
||||
WHERE legacy.oauth_name = 'google'
|
||||
AND legacy.account_id = 'people/' || canonical.account_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
|
@ -58,6 +58,7 @@ def create_create_automation_tool(
|
|||
``AsyncSession`` is opened per call to avoid stale sessions on
|
||||
compiled-agent cache hits (same pattern as the Notion / memory tools).
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""Draft + save an automation from a natural-language intent.
|
||||
|
|
|
|||
|
|
@ -242,6 +242,7 @@ def create_generate_image_tool(
|
|||
|
||||
# Update all image URLs in response_dict to be absolute (for the serving endpoint)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
for image in images:
|
||||
if image.get("url"):
|
||||
raw_url: str = image["url"]
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from datetime import UTC, datetime
|
|||
from threading import Lock
|
||||
|
||||
import redis
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@ -28,6 +28,7 @@ from app.agents.chat.runtime.checkpointer import (
|
|||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.csrf import CsrfOriginMiddleware
|
||||
from app.config import (
|
||||
config,
|
||||
initialize_image_gen_router,
|
||||
|
|
@ -53,8 +54,14 @@ from app.observability import metrics as ot_metrics
|
|||
from app.observability.bootstrap import init_otel, shutdown_otel
|
||||
from app.rate_limiter import get_real_client_ip, limiter
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.routes.auth_routes import (
|
||||
resolve_google_user,
|
||||
router as auth_router,
|
||||
session_router,
|
||||
)
|
||||
from app.routes.users_routes import router as users_router
|
||||
from app.routes.zero_context_routes import router as zero_context_router
|
||||
from app.schemas import UserCreate, UserRead
|
||||
from app.session_events import register_session_hooks
|
||||
from app.users import SECRET, allow_any_principal, auth_backend, fastapi_users
|
||||
from app.utils.perf import log_system_snapshot
|
||||
|
|
@ -803,6 +810,7 @@ allowed_origins.extend(
|
|||
]
|
||||
)
|
||||
|
||||
app.add_middleware(CsrfOriginMiddleware)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allowed_origins,
|
||||
|
|
@ -855,16 +863,14 @@ if config.AUTH_TYPE != "GOOGLE":
|
|||
tags=["auth"],
|
||||
)
|
||||
|
||||
# /users/me (read/update profile) is needed in every auth mode, so it stays
|
||||
# mounted unconditionally.
|
||||
app.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
# /users/me uses the unified auth resolver so web cookie sessions, desktop bearer
|
||||
# sessions, and PAT principals all resolve through the same authority.
|
||||
app.include_router(users_router)
|
||||
|
||||
# Include custom auth routes (refresh token, logout)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(session_router)
|
||||
app.include_router(zero_context_router)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
|
@ -890,36 +896,183 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
parsed_url = urlparse(config.BACKEND_URL)
|
||||
csrf_cookie_domain = parsed_url.hostname
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
is_verified_by_default=True,
|
||||
csrf_token_cookie_secure=is_secure_context,
|
||||
csrf_token_cookie_samesite=csrf_cookie_samesite,
|
||||
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.router.oauth import (
|
||||
CSRF_TOKEN_COOKIE_NAME,
|
||||
CSRF_TOKEN_KEY,
|
||||
STATE_TOKEN_AUDIENCE,
|
||||
generate_state_token,
|
||||
)
|
||||
from google.auth.transport import requests as google_requests
|
||||
from google.oauth2 import id_token as google_id_token
|
||||
|
||||
from app.users import get_user_manager
|
||||
|
||||
def _google_callback_url(request: Request) -> str:
|
||||
if config.BACKEND_URL:
|
||||
return f"{config.BACKEND_URL}/auth/google/callback"
|
||||
return str(request.url_for("google_oauth_callback"))
|
||||
|
||||
def _set_google_oauth_csrf_cookie(response: Response, csrf_token: str) -> None:
|
||||
response.set_cookie(
|
||||
key=CSRF_TOKEN_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
max_age=3600,
|
||||
path="/",
|
||||
domain=csrf_cookie_domain,
|
||||
secure=is_secure_context,
|
||||
httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
||||
samesite=csrf_cookie_samesite,
|
||||
)
|
||||
if not config.BACKEND_URL
|
||||
else fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
auth_backend,
|
||||
|
||||
async def _google_authorization_url(request: Request, response: Response) -> str:
|
||||
import secrets
|
||||
|
||||
csrf_token = secrets.token_urlsafe(32)
|
||||
state = generate_state_token(
|
||||
{CSRF_TOKEN_KEY: csrf_token},
|
||||
SECRET,
|
||||
is_verified_by_default=True,
|
||||
redirect_url=f"{config.BACKEND_URL}/auth/google/callback",
|
||||
csrf_token_cookie_secure=is_secure_context,
|
||||
csrf_token_cookie_samesite=csrf_cookie_samesite,
|
||||
csrf_token_cookie_httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
||||
csrf_token_cookie_domain=csrf_cookie_domain, # Explicitly set cookie domain
|
||||
),
|
||||
prefix="/auth/google",
|
||||
lifetime_seconds=3600,
|
||||
)
|
||||
authorization_url = await google_oauth_client.get_authorization_url(
|
||||
_google_callback_url(request),
|
||||
state,
|
||||
scope=["openid", "email", "profile"],
|
||||
)
|
||||
_set_google_oauth_csrf_cookie(response, csrf_token)
|
||||
return authorization_url
|
||||
|
||||
@app.get(
|
||||
"/auth/google/authorize",
|
||||
tags=["auth"],
|
||||
# REGISTRATION_ENABLED is a master auth kill switch: when set to FALSE
|
||||
# it blocks BOTH new OAuth signups AND login of existing OAuth users
|
||||
# (the fastapi-users OAuth router shares one callback for create+login,
|
||||
# so this dependency closes both paths together).
|
||||
dependencies=[Depends(registration_allowed)],
|
||||
)
|
||||
async def google_authorize(request: Request, response: Response):
|
||||
"""Return Google's authorization URL, matching fastapi-users' shape."""
|
||||
return {"authorization_url": await _google_authorization_url(request, response)}
|
||||
|
||||
@app.get(
|
||||
"/auth/google/callback",
|
||||
name="google_oauth_callback",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(registration_allowed)],
|
||||
)
|
||||
async def google_oauth_callback(
|
||||
request: Request,
|
||||
user_manager=Depends(get_user_manager),
|
||||
):
|
||||
"""Handle web Google OAuth with the same verified-email policy as desktop."""
|
||||
import secrets
|
||||
|
||||
import httpx
|
||||
import jwt as pyjwt
|
||||
|
||||
state = request.query_params.get("state")
|
||||
code = request.query_params.get("code")
|
||||
if not state or not code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth callback missing code or state",
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state, SECRET, [STATE_TOKEN_AUDIENCE])
|
||||
except pyjwt.DecodeError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="ACCESS_TOKEN_DECODE_ERROR",
|
||||
) from exc
|
||||
except pyjwt.ExpiredSignatureError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
) from exc
|
||||
|
||||
cookie_csrf_token = request.cookies.get(CSRF_TOKEN_COOKIE_NAME)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAUTH_INVALID_STATE",
|
||||
)
|
||||
|
||||
token_payload = {
|
||||
"client_id": config.GOOGLE_OAUTH_CLIENT_ID,
|
||||
"client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": _google_callback_url(request),
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
token_response = await client.post(
|
||||
"https://oauth2.googleapis.com/token",
|
||||
data=token_payload,
|
||||
)
|
||||
if token_response.status_code >= 400:
|
||||
_error_logger.warning(
|
||||
"Web Google OAuth exchange failed: %s", token_response.text
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="OAuth exchange failed",
|
||||
)
|
||||
|
||||
token_data = token_response.json()
|
||||
google_access_token = token_data.get("access_token")
|
||||
google_id_token_value = token_data.get("id_token")
|
||||
if not google_access_token or not google_id_token_value:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="OAuth exchange failed",
|
||||
)
|
||||
|
||||
try:
|
||||
claims = google_id_token.verify_oauth2_token(
|
||||
google_id_token_value,
|
||||
google_requests.Request(),
|
||||
config.GOOGLE_OAUTH_CLIENT_ID,
|
||||
)
|
||||
except Exception as exc:
|
||||
_error_logger.warning("Web Google id_token verification failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Google identity token",
|
||||
) from exc
|
||||
|
||||
expires_at = (
|
||||
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
||||
if token_data.get("expires_in")
|
||||
else None
|
||||
)
|
||||
user = await resolve_google_user(
|
||||
user_manager=user_manager,
|
||||
request=request,
|
||||
google_access_token=google_access_token,
|
||||
claims=claims,
|
||||
expires_at=expires_at,
|
||||
google_refresh_token=token_data.get("refresh_token"),
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="LOGIN_BAD_CREDENTIALS",
|
||||
)
|
||||
|
||||
response = await auth_backend.login(auth_backend.get_strategy(), user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
response.delete_cookie(
|
||||
key=CSRF_TOKEN_COOKIE_NAME,
|
||||
path="/",
|
||||
domain=csrf_cookie_domain,
|
||||
secure=is_secure_context,
|
||||
samesite=csrf_cookie_samesite,
|
||||
httponly=False,
|
||||
)
|
||||
return response
|
||||
|
||||
# Add a redirect-based authorize endpoint for Firefox/Safari compatibility
|
||||
# This endpoint performs a server-side redirect instead of returning JSON
|
||||
|
|
@ -944,43 +1097,9 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
This fixes CSRF cookie issues in Firefox and Safari where cookies set
|
||||
via cross-origin fetch requests are not sent on subsequent redirects.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
from fastapi_users.router.oauth import generate_state_token
|
||||
|
||||
# Generate CSRF token
|
||||
csrf_token = secrets.token_urlsafe(32)
|
||||
|
||||
# Build state token
|
||||
state_data = {"csrftoken": csrf_token}
|
||||
state = generate_state_token(state_data, SECRET, lifetime_seconds=3600)
|
||||
|
||||
# Get the callback URL
|
||||
if config.BACKEND_URL:
|
||||
redirect_url = f"{config.BACKEND_URL}/auth/google/callback"
|
||||
else:
|
||||
redirect_url = str(request.url_for("oauth:google.jwt.callback"))
|
||||
|
||||
# Get authorization URL from Google
|
||||
authorization_url = await google_oauth_client.get_authorization_url(
|
||||
redirect_url,
|
||||
state,
|
||||
scope=["openid", "email", "profile"],
|
||||
)
|
||||
|
||||
# Create redirect response and set CSRF cookie
|
||||
response = RedirectResponse(url=authorization_url, status_code=302)
|
||||
response.set_cookie(
|
||||
key="fastapiusersoauthcsrf",
|
||||
value=csrf_token,
|
||||
max_age=3600,
|
||||
path="/",
|
||||
domain=csrf_cookie_domain,
|
||||
secure=is_secure_context,
|
||||
httponly=False, # Required for cross-site OAuth in Firefox/Safari
|
||||
samesite=csrf_cookie_samesite,
|
||||
)
|
||||
|
||||
response = RedirectResponse(url="", status_code=302)
|
||||
authorization_url = await _google_authorization_url(request, response)
|
||||
response.headers["location"] = authorization_url
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
|||
61
surfsense_backend/app/auth/csrf.py
Normal file
61
surfsense_backend/app/auth/csrf.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""CSRF protection for ambient cookie-authenticated requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from app.config import config
|
||||
|
||||
UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||
|
||||
|
||||
def _origin_from_url(url: str | None) -> str | None:
|
||||
if not url:
|
||||
return None
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return None
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
|
||||
def _allowed_origins() -> set[str]:
|
||||
origins = set(config.CSRF_ALLOWED_ORIGINS)
|
||||
for url in (config.NEXT_FRONTEND_URL, config.SURFSENSE_PUBLIC_URL):
|
||||
origin = _origin_from_url(url)
|
||||
if origin:
|
||||
origins.add(origin)
|
||||
return origins
|
||||
|
||||
|
||||
class CsrfOriginMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: RequestResponseEndpoint,
|
||||
) -> Response:
|
||||
if request.method not in UNSAFE_METHODS:
|
||||
return await call_next(request)
|
||||
|
||||
# PAT/Bearer credentials are not ambient browser credentials and are not
|
||||
# CSRF-able. Enforce only when the web session cookie is the credential.
|
||||
if (
|
||||
request.headers.get("Authorization")
|
||||
or config.SESSION_COOKIE_NAME not in request.cookies
|
||||
):
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("Origin") or _origin_from_url(
|
||||
request.headers.get("Referer")
|
||||
)
|
||||
if origin not in _allowed_origins():
|
||||
return JSONResponse(
|
||||
{"detail": "CSRF origin check failed"},
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
130
surfsense_backend/app/auth/session_cookies.py
Normal file
130
surfsense_backend/app/auth/session_cookies.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""Centralized session-cookie I/O for web authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import Request, Response
|
||||
|
||||
from app.config import config
|
||||
|
||||
|
||||
class TransportMode(Enum):
|
||||
COOKIE = "cookie"
|
||||
HEADER = "header"
|
||||
|
||||
|
||||
def _cookie_secure(request: Request | None = None) -> bool:
|
||||
policy = config.SESSION_COOKIE_SECURE_POLICY
|
||||
if policy == "always":
|
||||
return True
|
||||
if policy == "never":
|
||||
return False
|
||||
if request is not None:
|
||||
proto = request.headers.get("x-forwarded-proto")
|
||||
if proto:
|
||||
return proto.split(",", 1)[0].strip().lower() == "https"
|
||||
return request.url.scheme == "https"
|
||||
return bool(config.BACKEND_URL and config.BACKEND_URL.startswith("https://"))
|
||||
|
||||
|
||||
def _set_persistent_cookie(
|
||||
response: Response,
|
||||
*,
|
||||
key: str,
|
||||
value: str,
|
||||
max_age: int,
|
||||
request: Request | None,
|
||||
) -> None:
|
||||
expires = datetime.now(UTC) + timedelta(seconds=max_age)
|
||||
response.set_cookie(
|
||||
key=key,
|
||||
value=value,
|
||||
max_age=max_age,
|
||||
expires=expires,
|
||||
httponly=True,
|
||||
secure=_cookie_secure(request),
|
||||
samesite=config.SESSION_COOKIE_SAMESITE,
|
||||
domain=config.COOKIE_DOMAIN,
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def write_session(
|
||||
response: Response,
|
||||
access: str,
|
||||
refresh: str | None = None,
|
||||
request: Request | None = None,
|
||||
) -> None:
|
||||
_set_persistent_cookie(
|
||||
response,
|
||||
key=config.SESSION_COOKIE_NAME,
|
||||
value=access,
|
||||
max_age=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
request=request,
|
||||
)
|
||||
if refresh is not None:
|
||||
_set_persistent_cookie(
|
||||
response,
|
||||
key=config.REFRESH_COOKIE_NAME,
|
||||
value=refresh,
|
||||
max_age=config.REFRESH_TOKEN_LIFETIME_SECONDS,
|
||||
request=request,
|
||||
)
|
||||
|
||||
|
||||
def clear_session(response: Response, request: Request | None = None) -> None:
|
||||
for key in (config.SESSION_COOKIE_NAME, config.REFRESH_COOKIE_NAME):
|
||||
response.delete_cookie(
|
||||
key=key,
|
||||
path="/",
|
||||
domain=config.COOKIE_DOMAIN,
|
||||
secure=_cookie_secure(request),
|
||||
samesite=config.SESSION_COOKIE_SAMESITE,
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
|
||||
def read_refresh(
|
||||
request: Request, body: Any | None = None
|
||||
) -> tuple[str | None, TransportMode]:
|
||||
cookie = request.cookies.get(config.REFRESH_COOKIE_NAME)
|
||||
if cookie:
|
||||
return cookie, TransportMode.COOKIE
|
||||
if body is None:
|
||||
return None, TransportMode.HEADER
|
||||
return getattr(body, "refresh_token", None), TransportMode.HEADER
|
||||
|
||||
|
||||
def access_expires_at(access_token: str) -> int:
|
||||
payload = jwt.decode(
|
||||
access_token,
|
||||
config.SECRET_KEY,
|
||||
algorithms=["HS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
return int(payload["exp"])
|
||||
|
||||
|
||||
def issue(
|
||||
response: Response,
|
||||
mode: TransportMode,
|
||||
*,
|
||||
access: str,
|
||||
refresh: str | None,
|
||||
access_expires_at: int,
|
||||
request: Request | None = None,
|
||||
) -> dict:
|
||||
if mode is TransportMode.COOKIE:
|
||||
write_session(response, access, refresh, request)
|
||||
return {"authenticated": True, "access_expires_at": access_expires_at}
|
||||
|
||||
return {
|
||||
"access_token": access,
|
||||
"refresh_token": refresh,
|
||||
"token_type": "bearer",
|
||||
"access_expires_at": access_expires_at,
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
|
|
@ -27,7 +28,6 @@ from app.automations.services.model_policy import (
|
|||
)
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, SearchSpace, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from fastapi import Depends, HTTPException
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from fastapi import Depends, HTTPException
|
|||
from pydantic import ValidationError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.persistence.enums.trigger_type import TriggerType
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.trigger import AutomationTrigger
|
||||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
|
|||
|
|
@ -188,6 +188,7 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.tasks.celery_tasks.refresh_token_cleanup_task",
|
||||
"app.tasks.celery_tasks.auto_reload_task",
|
||||
"app.tasks.celery_tasks.gateway_tasks",
|
||||
"app.etl_pipeline.cache.eviction.task",
|
||||
|
|
@ -306,6 +307,11 @@ celery_app.conf.beat_schedule = {
|
|||
"schedule": crontab(hour="3", minute="17"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
"purge-refresh-tokens": {
|
||||
"task": "purge_refresh_tokens",
|
||||
"schedule": crontab(hour="3", minute="41"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Prune the ETL parse cache (TTL + size budget) once daily, off-peak.
|
||||
"evict-etl-cache": {
|
||||
"task": "evict_etl_cache",
|
||||
|
|
|
|||
|
|
@ -768,6 +768,8 @@ class Config:
|
|||
# Google OAuth
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
GOOGLE_DESKTOP_CLIENT_ID = os.getenv("GOOGLE_DESKTOP_CLIENT_ID")
|
||||
GOOGLE_DESKTOP_CLIENT_SECRET = os.getenv("GOOGLE_DESKTOP_CLIENT_SECRET")
|
||||
GOOGLE_PICKER_API_KEY = os.getenv("GOOGLE_PICKER_API_KEY")
|
||||
|
||||
# Google Calendar redirect URI
|
||||
|
|
@ -914,15 +916,39 @@ class Config:
|
|||
|
||||
# JWT Token Lifetimes
|
||||
ACCESS_TOKEN_LIFETIME_SECONDS = int(
|
||||
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
|
||||
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(30 * 60)) # 30 minutes
|
||||
)
|
||||
MIN_ISSUED_AT = int(os.getenv("MIN_ISSUED_AT", "0"))
|
||||
REFRESH_TOKEN_LIFETIME_SECONDS = int(
|
||||
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
|
||||
)
|
||||
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
|
||||
PAT_MAX_EXPIRY_DAYS = (
|
||||
int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
|
||||
REFRESH_ROTATION_GRACE_SECONDS = int(
|
||||
os.getenv("REFRESH_ROTATION_GRACE_SECONDS", "45")
|
||||
)
|
||||
REFRESH_ABSOLUTE_LIFETIME_SECONDS = int(
|
||||
os.getenv("REFRESH_ABSOLUTE_LIFETIME_SECONDS", str(30 * 24 * 60 * 60))
|
||||
)
|
||||
if REFRESH_ABSOLUTE_LIFETIME_SECONDS <= REFRESH_TOKEN_LIFETIME_SECONDS:
|
||||
raise ValueError(
|
||||
"REFRESH_ABSOLUTE_LIFETIME_SECONDS must be greater than "
|
||||
"REFRESH_TOKEN_LIFETIME_SECONDS so the sliding inactivity window works."
|
||||
)
|
||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "surfsense_session")
|
||||
REFRESH_COOKIE_NAME = os.getenv("REFRESH_COOKIE_NAME", "surfsense_refresh")
|
||||
SESSION_COOKIE_SECURE_POLICY = os.getenv(
|
||||
"SESSION_COOKIE_SECURE_POLICY", "auto"
|
||||
).lower()
|
||||
SESSION_COOKIE_SAMESITE = os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower()
|
||||
if SESSION_COOKIE_SAMESITE == "none":
|
||||
raise ValueError("SESSION_COOKIE_SAMESITE=none is not supported")
|
||||
COOKIE_DOMAIN = os.getenv("COOKIE_DOMAIN") or None
|
||||
CSRF_ALLOWED_ORIGINS = [
|
||||
origin.strip()
|
||||
for origin in os.getenv("CSRF_ALLOWED_ORIGINS", "").split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
_PAT_MAX_EXPIRY_DAYS = os.getenv("PAT_MAX_EXPIRY_DAYS", "").strip()
|
||||
PAT_MAX_EXPIRY_DAYS = int(_PAT_MAX_EXPIRY_DAYS) if _PAT_MAX_EXPIRY_DAYS else None
|
||||
|
||||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
|
|
|||
|
|
@ -2714,9 +2714,10 @@ class RefreshToken(Base, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
user = relationship("User", back_populates="refresh_tokens")
|
||||
token_hash = Column(String(256), unique=True, nullable=False, index=True)
|
||||
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
|
||||
is_revoked = Column(Boolean, default=False, nullable=False)
|
||||
revoked_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
absolute_expiry = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
@property
|
||||
|
|
@ -2725,7 +2726,7 @@ class RefreshToken(Base, TimestampMixin):
|
|||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return not self.is_expired and not self.is_revoked
|
||||
return not self.is_expired and self.revoked_at is None
|
||||
|
||||
|
||||
class PersonalAccessToken(BaseModel, TimestampMixin):
|
||||
|
|
|
|||
|
|
@ -28,13 +28,12 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import get_auth_context
|
||||
|
|
@ -114,7 +113,6 @@ async def list_thread_actions(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AgentActionListResponse:
|
||||
user = auth.user
|
||||
"""List agent actions for a thread, newest first.
|
||||
|
||||
Authorization:
|
||||
|
|
|
|||
|
|
@ -30,14 +30,13 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
AgentPermissionRule,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import get_auth_context
|
||||
|
|
@ -136,7 +135,6 @@ def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead:
|
|||
async def _ensure_search_space_membership_admin(
|
||||
session: AsyncSession, auth: AuthContext, search_space_id: int
|
||||
) -> None:
|
||||
user = auth.user
|
||||
"""Curating agent rules == "settings" administration on the space."""
|
||||
space = await session.get(SearchSpace, search_space_id)
|
||||
if space is None:
|
||||
|
|
|
|||
|
|
@ -1,21 +1,46 @@
|
|||
"""Authentication routes for refresh token management."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi_users import exceptions as fastapi_users_exceptions
|
||||
from google.auth.transport import requests as google_requests
|
||||
from google.oauth2 import id_token as google_id_token
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, async_session_maker
|
||||
from app.auth.session_cookies import (
|
||||
access_expires_at,
|
||||
clear_session,
|
||||
issue,
|
||||
read_refresh,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import User, async_session_maker, get_async_session
|
||||
from app.rate_limiter import limiter
|
||||
from app.schemas.auth import (
|
||||
DesktopLoginRequest,
|
||||
DesktopSessionRequest,
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
SessionResponse,
|
||||
)
|
||||
from app.users import (
|
||||
UserManager,
|
||||
get_auth_context,
|
||||
get_jwt_strategy,
|
||||
get_user_manager,
|
||||
)
|
||||
from app.users import get_jwt_strategy, require_session_context
|
||||
from app.utils.refresh_tokens import (
|
||||
create_refresh_token,
|
||||
revoke_all_user_tokens,
|
||||
revoke_refresh_token,
|
||||
rotate_refresh_token,
|
||||
|
|
@ -25,57 +50,140 @@ from app.utils.refresh_tokens import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
||||
session_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_access_token(request: RefreshTokenRequest):
|
||||
async def _load_user(user_id) -> User | None:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def resolve_google_user(
|
||||
*,
|
||||
user_manager: UserManager,
|
||||
request: Request,
|
||||
google_access_token: str,
|
||||
claims: dict,
|
||||
expires_at: int | None = None,
|
||||
google_refresh_token: str | None = None,
|
||||
) -> User:
|
||||
"""Resolve a Google identity with one policy for web and desktop OAuth.
|
||||
|
||||
Email-based account linking is only allowed when Google asserts that the
|
||||
email is verified. Existing OAuth accounts continue to resolve by provider
|
||||
account id regardless of the current email claim.
|
||||
"""
|
||||
if not claims.get("sub") or not claims.get("email"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Google identity token",
|
||||
)
|
||||
|
||||
sub = claims["sub"]
|
||||
email_verified = bool(claims.get("email_verified"))
|
||||
|
||||
canonical_user = await user_manager.user_db.get_by_oauth_account("google", sub)
|
||||
if canonical_user is None:
|
||||
legacy_account_id = f"people/{sub}"
|
||||
legacy_user = await user_manager.user_db.get_by_oauth_account(
|
||||
"google", legacy_account_id
|
||||
)
|
||||
if legacy_user is not None:
|
||||
# Fallback for pre-sub Google OAuth rows created by the old web flow.
|
||||
# TODO: Remove after oauth_account is fully backfilled to bare Google
|
||||
# sub and production has zero google rows with account_id LIKE 'people/%'.
|
||||
for oauth_account in legacy_user.oauth_accounts:
|
||||
if (
|
||||
oauth_account.oauth_name == "google"
|
||||
and oauth_account.account_id == legacy_account_id
|
||||
):
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
legacy_user,
|
||||
oauth_account,
|
||||
{"account_id": sub},
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
return await user_manager.oauth_callback(
|
||||
"google",
|
||||
google_access_token,
|
||||
sub,
|
||||
claims["email"],
|
||||
expires_at=expires_at,
|
||||
refresh_token=google_refresh_token,
|
||||
request=request,
|
||||
associate_by_email=email_verified,
|
||||
is_verified_by_default=email_verified,
|
||||
)
|
||||
except fastapi_users_exceptions.UserAlreadyExists as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAUTH_USER_ALREADY_EXISTS",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=None)
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_access_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: RefreshTokenRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Exchange a valid refresh token for a new access token and refresh token.
|
||||
Implements token rotation for security.
|
||||
"""
|
||||
token_record = await validate_refresh_token(request.refresh_token)
|
||||
|
||||
if not token_record:
|
||||
refresh_token, mode = read_refresh(request, body)
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
# Get user from token record
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == token_record.user_id)
|
||||
rotation = await rotate_refresh_token(refresh_token)
|
||||
if not rotation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
user = await _load_user(rotation.user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
# Generate new access token
|
||||
strategy = get_jwt_strategy()
|
||||
access_token = await strategy.write_token(user)
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh_token = await rotate_refresh_token(token_record)
|
||||
|
||||
logger.info(f"Refreshed token for user {user.id}")
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
return issue(
|
||||
response,
|
||||
mode,
|
||||
access=access_token,
|
||||
refresh=rotation.refresh_token,
|
||||
access_expires_at=access_expires_at(access_token),
|
||||
request=request,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/revoke", response_model=LogoutResponse)
|
||||
async def revoke_token(request: LogoutRequest):
|
||||
async def revoke_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: LogoutRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Logout current device by revoking the provided refresh token.
|
||||
Does not require authentication - just the refresh token.
|
||||
"""
|
||||
revoked = await revoke_refresh_token(request.refresh_token)
|
||||
refresh_token, _mode = read_refresh(request, body)
|
||||
revoked = await revoke_refresh_token(refresh_token) if refresh_token else False
|
||||
clear_session(response, request)
|
||||
if revoked:
|
||||
logger.info("User logged out from current device - token revoked")
|
||||
else:
|
||||
|
|
@ -85,13 +193,185 @@ async def revoke_token(request: LogoutRequest):
|
|||
|
||||
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||
async def logout_all_devices(
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: LogoutRequest | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
"""
|
||||
Logout from all devices by revoking all refresh tokens for the user.
|
||||
Requires valid access token.
|
||||
"""
|
||||
user = auth.user
|
||||
user: User | None = None
|
||||
try:
|
||||
auth = await get_auth_context(
|
||||
request, session=session, user_manager=user_manager
|
||||
)
|
||||
if auth.is_session:
|
||||
user = auth.user
|
||||
except HTTPException:
|
||||
user = None
|
||||
|
||||
if user is None:
|
||||
refresh_token, _mode = read_refresh(request, body)
|
||||
token_record = (
|
||||
await validate_refresh_token(refresh_token) if refresh_token else None
|
||||
)
|
||||
if token_record:
|
||||
user = await _load_user(token_record.user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
await revoke_all_user_tokens(user.id)
|
||||
clear_session(response, request)
|
||||
logger.info(f"User {user.id} logged out from all devices")
|
||||
return LogoutAllResponse()
|
||||
|
||||
|
||||
@session_router.get("/session", response_model=SessionResponse)
|
||||
async def get_session(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
if auth.method == "pat":
|
||||
return SessionResponse(access_expires_at=None)
|
||||
|
||||
access_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if access_token is None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() == "bearer" and token:
|
||||
access_token = token
|
||||
|
||||
if access_token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||
)
|
||||
return SessionResponse(access_expires_at=access_expires_at(access_token))
|
||||
|
||||
|
||||
@session_router.post("/desktop/login", response_model=RefreshTokenResponse)
|
||||
@limiter.limit("5/minute")
|
||||
async def desktop_password_login(
|
||||
request: Request,
|
||||
body: DesktopLoginRequest,
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
if not config.REGISTRATION_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Registration is disabled",
|
||||
)
|
||||
|
||||
credentials = SimpleNamespace(username=body.email, password=body.password)
|
||||
user = await user_manager.authenticate(credentials)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="LOGIN_BAD_CREDENTIALS",
|
||||
)
|
||||
|
||||
app_access_token = await get_jwt_strategy().write_token(user)
|
||||
app_refresh_token = await create_refresh_token(user.id)
|
||||
await user_manager.on_after_login(user, request, None)
|
||||
return RefreshTokenResponse(
|
||||
access_token=app_access_token,
|
||||
refresh_token=app_refresh_token,
|
||||
access_expires_at=access_expires_at(app_access_token),
|
||||
)
|
||||
|
||||
|
||||
@session_router.post("/desktop/session", response_model=RefreshTokenResponse)
|
||||
@limiter.limit("20/minute")
|
||||
async def create_desktop_session(
|
||||
request: Request,
|
||||
body: DesktopSessionRequest,
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
parsed_redirect = urlparse(body.redirect_uri)
|
||||
try:
|
||||
redirect_port = parsed_redirect.port
|
||||
except ValueError:
|
||||
redirect_port = None
|
||||
if not (
|
||||
parsed_redirect.scheme == "http"
|
||||
and parsed_redirect.hostname in {"127.0.0.1", "::1"}
|
||||
and redirect_port is not None
|
||||
and parsed_redirect.path == "/callback"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect URI"
|
||||
)
|
||||
if not config.GOOGLE_DESKTOP_CLIENT_ID:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Desktop OAuth is not configured",
|
||||
)
|
||||
|
||||
token_payload = {
|
||||
"client_id": config.GOOGLE_DESKTOP_CLIENT_ID,
|
||||
"code": body.code,
|
||||
"code_verifier": body.code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": body.redirect_uri,
|
||||
}
|
||||
if config.GOOGLE_DESKTOP_CLIENT_SECRET:
|
||||
token_payload["client_secret"] = config.GOOGLE_DESKTOP_CLIENT_SECRET
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
token_response = await client.post(
|
||||
"https://oauth2.googleapis.com/token", data=token_payload
|
||||
)
|
||||
if token_response.status_code >= 400:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed"
|
||||
)
|
||||
token_data = token_response.json()
|
||||
|
||||
id_token = token_data.get("id_token")
|
||||
access_token = token_data.get("access_token")
|
||||
if not id_token or not access_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="OAuth exchange failed"
|
||||
)
|
||||
|
||||
try:
|
||||
claims = google_id_token.verify_oauth2_token(
|
||||
id_token,
|
||||
google_requests.Request(),
|
||||
config.GOOGLE_DESKTOP_CLIENT_ID,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Desktop Google id_token verification failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Google identity token",
|
||||
) from exc
|
||||
|
||||
user = await resolve_google_user(
|
||||
user_manager=user_manager,
|
||||
request=request,
|
||||
google_access_token=access_token,
|
||||
claims=claims,
|
||||
expires_at=(
|
||||
int(datetime.now(UTC).timestamp()) + int(token_data["expires_in"])
|
||||
if token_data.get("expires_in")
|
||||
else None
|
||||
),
|
||||
google_refresh_token=token_data.get("refresh_token"),
|
||||
)
|
||||
app_access_token = await get_jwt_strategy().write_token(user)
|
||||
app_refresh_token = await create_refresh_token(user.id)
|
||||
return RefreshTokenResponse(
|
||||
access_token=app_access_token,
|
||||
refresh_token=app_refresh_token,
|
||||
access_expires_at=access_expires_at(app_access_token),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.runtime.path_resolver import virtual_path_to_doc
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
|
|
@ -18,7 +18,6 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
|
|
@ -684,7 +683,6 @@ async def search_document_titles(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Lightweight document title search optimized for mention picker (@mentions).
|
||||
|
||||
|
|
@ -789,7 +787,6 @@ async def get_document_by_virtual_path(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Resolve a knowledge-base document by its agent-facing virtual path.
|
||||
|
||||
The agent renders every document under ``/documents/...`` with a
|
||||
|
|
@ -847,7 +844,6 @@ async def get_documents_status(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Batch status endpoint for documents in a search space.
|
||||
|
||||
|
|
@ -1071,7 +1067,6 @@ async def get_watched_folders(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -1113,7 +1108,6 @@ async def get_document_chunks_paginated(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Paginated chunk loading for a document.
|
||||
Supports both page-based and offset-based access.
|
||||
|
|
@ -1175,7 +1169,6 @@ async def read_document(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific document by ID.
|
||||
Requires DOCUMENTS_READ permission for the search space.
|
||||
|
|
@ -1230,7 +1223,6 @@ async def update_document(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a document.
|
||||
Requires DOCUMENTS_UPDATE permission for the search space.
|
||||
|
|
@ -1290,7 +1282,6 @@ async def delete_document(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a document.
|
||||
Requires DOCUMENTS_DELETE permission for the search space.
|
||||
|
|
@ -1536,7 +1527,6 @@ async def folder_mtime_check(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Pre-upload optimization: check which files need uploading based on mtime.
|
||||
|
||||
Returns the subset of relative paths where the file is new or has a
|
||||
|
|
@ -1754,7 +1744,6 @@ async def folder_unlink(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Handle file deletion events from the desktop watcher.
|
||||
|
||||
For each relative path, find the matching document and delete it.
|
||||
|
|
@ -1809,7 +1798,6 @@ async def folder_sync_finalize(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Finalize a full folder scan by deleting orphaned documents.
|
||||
|
||||
The client sends the complete list of relative paths currently in the
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
|
||||
from app.db import Chunk, Document, DocumentType, Permission, get_async_session
|
||||
from app.routes.reports_routes import (
|
||||
_FILE_EXTENSIONS,
|
||||
_MEDIA_TYPES,
|
||||
|
|
@ -50,7 +50,6 @@ async def get_editor_content(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get document content for editing.
|
||||
|
||||
|
|
@ -182,7 +181,6 @@ async def download_document_markdown(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Download the full document content as a .md file.
|
||||
Reconstructs markdown from source_markdown or chunks.
|
||||
|
|
@ -337,7 +335,6 @@ async def export_document(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Export a document in the requested format (reuses the report export pipeline)."""
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from fastapi.responses import StreamingResponse
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.db import Permission, get_async_session
|
||||
from app.services.export_service import build_export_zip
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -27,7 +27,6 @@ async def export_knowledge_base(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Export documents as a ZIP of markdown files preserving folder structure."""
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Document, Folder, Permission, User, get_async_session
|
||||
from app.db import Document, Folder, Permission, get_async_session
|
||||
from app.schemas import (
|
||||
BulkDocumentMove,
|
||||
DocumentMove,
|
||||
|
|
@ -95,7 +95,6 @@ async def list_folders(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
await check_permission(
|
||||
|
|
@ -127,7 +126,6 @@ async def get_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get a single folder. Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -158,7 +156,6 @@ async def get_folder_breadcrumb(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -203,7 +200,6 @@ async def stop_watching_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Clear the watched flag from a folder's metadata."""
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
|
|
@ -232,7 +228,6 @@ async def update_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Rename a folder. Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -273,7 +268,6 @@ async def move_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -334,7 +328,6 @@ async def reorder_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -376,7 +369,6 @@ async def delete_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
|
|
@ -451,7 +443,6 @@ async def move_document(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -498,7 +489,6 @@ async def bulk_move_documents(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
|
||||
try:
|
||||
if not request.document_ids:
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ from app.db import (
|
|||
ExternalChatHealthStatus,
|
||||
ExternalChatPeerKind,
|
||||
ExternalChatPlatform,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.gateway.accounts import (
|
||||
|
|
@ -979,7 +978,6 @@ async def list_platforms(
|
|||
async def get_gateway_config(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict[str, bool | str]:
|
||||
user = auth.user
|
||||
if not config.GATEWAY_ENABLED:
|
||||
return {
|
||||
"enabled": False,
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ async def request_pairing_code(
|
|||
async def bridge_health(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict[str, Any]:
|
||||
user = auth.user
|
||||
_ensure_baileys_enabled()
|
||||
adapter = WhatsAppBaileysAdapter()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
|
|
@ -224,6 +223,7 @@ async def _execute_image_generation(
|
|||
|
||||
# Fix relative URLs in response data (for the serving endpoint)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
images = response_dict.get("data", [])
|
||||
provider_base_url = resolved_kwargs.get("api_base")
|
||||
for image in images:
|
||||
|
|
@ -422,7 +422,6 @@ async def get_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Get a specific image generation by ID."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -455,7 +454,6 @@ async def delete_image_generation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Delete an image generation record."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import LogCreate, LogRead, LogUpdate
|
||||
|
|
@ -29,7 +28,6 @@ async def create_log(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new log entry.
|
||||
Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated).
|
||||
|
|
@ -141,7 +139,6 @@ async def read_log(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific log by ID.
|
||||
Requires LOGS_READ permission for the search space.
|
||||
|
|
@ -178,7 +175,6 @@ async def update_log(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a log entry.
|
||||
Requires LOGS_READ permission (logs are typically updated by system).
|
||||
|
|
@ -222,7 +218,6 @@ async def delete_log(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a log entry.
|
||||
Requires LOGS_DELETE permission for the search space.
|
||||
|
|
@ -262,7 +257,6 @@ async def get_logs_summary(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a summary of logs for a search space in the last X hours.
|
||||
Requires LOGS_READ permission for the search space.
|
||||
|
|
|
|||
|
|
@ -325,7 +325,9 @@ async def _assert_connection_access(
|
|||
|
||||
|
||||
@router.get("/global-llm-config-status")
|
||||
async def global_llm_config_status(auth: AuthContext = Depends(require_session_context)):
|
||||
async def global_llm_config_status(
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
):
|
||||
del auth
|
||||
return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Document, DocumentType, Permission, User, get_async_session
|
||||
from app.db import Document, DocumentType, Permission, get_async_session
|
||||
from app.schemas import DocumentRead, PaginatedResponse
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -102,7 +102,6 @@ async def list_notes(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all notes in a search space.
|
||||
|
||||
|
|
@ -196,7 +195,6 @@ async def delete_note(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a note.
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,6 @@ PERMISSION_DESCRIPTIONS = {
|
|||
async def list_all_permissions(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all available permissions that can be assigned to roles.
|
||||
"""
|
||||
|
|
@ -162,7 +161,6 @@ async def create_role(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new custom role in a search space.
|
||||
Requires ROLES_CREATE permission.
|
||||
|
|
@ -244,7 +242,6 @@ async def list_roles(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all roles in a search space.
|
||||
Requires ROLES_READ permission.
|
||||
|
|
@ -283,7 +280,6 @@ async def get_role(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific role by ID.
|
||||
Requires ROLES_READ permission.
|
||||
|
|
@ -329,7 +325,6 @@ async def update_role(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a role.
|
||||
Requires ROLES_UPDATE permission.
|
||||
|
|
@ -427,7 +422,6 @@ async def delete_role(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a custom role.
|
||||
Requires ROLES_DELETE permission.
|
||||
|
|
@ -485,7 +479,6 @@ async def list_members(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all members of a search space.
|
||||
Requires MEMBERS_VIEW permission.
|
||||
|
|
@ -551,7 +544,6 @@ async def update_member_role(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a member's role.
|
||||
Requires MEMBERS_MANAGE_ROLES permission.
|
||||
|
|
@ -689,7 +681,6 @@ async def remove_member(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Remove a member from a search space.
|
||||
Requires MEMBERS_REMOVE permission.
|
||||
|
|
@ -814,7 +805,6 @@ async def list_invites(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all invites for a search space.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -854,7 +844,6 @@ async def update_invite(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update an invite.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
@ -921,7 +910,6 @@ async def revoke_invite(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Revoke (delete) an invite.
|
||||
Requires MEMBERS_INVITE permission.
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from app.db import (
|
|||
Report,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import ReportContentRead, ReportContentUpdate, ReportRead
|
||||
|
|
@ -161,7 +160,6 @@ async def _get_report_with_access(
|
|||
session: AsyncSession,
|
||||
auth: AuthContext,
|
||||
) -> Report:
|
||||
user = auth.user
|
||||
"""Fetch a report and verify the user belongs to its search space.
|
||||
|
||||
Raises HTTPException(404) if not found, HTTPException(403) if no access.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import NewChatThread, Permission, User, get_async_session
|
||||
from app.db import NewChatThread, Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -50,7 +50,6 @@ async def download_sandbox_file(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Download a file from the Daytona sandbox associated with a chat thread."""
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import (
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
)
|
||||
|
|
@ -286,7 +285,6 @@ async def read_search_source_connectors(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all search source connectors for a search space.
|
||||
Requires CONNECTORS_READ permission.
|
||||
|
|
@ -330,7 +328,6 @@ async def read_search_source_connector(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific search source connector by ID.
|
||||
Requires CONNECTORS_READ permission.
|
||||
|
|
@ -565,7 +562,6 @@ async def delete_search_source_connector(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a search source connector and all its associated documents.
|
||||
|
||||
|
|
@ -2735,7 +2731,6 @@ async def list_mcp_connectors(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all MCP connectors for a search space.
|
||||
|
||||
|
|
@ -2787,7 +2782,6 @@ async def get_mcp_connector(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific MCP connector by ID.
|
||||
|
||||
|
|
@ -2841,7 +2835,6 @@ async def update_mcp_connector(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update an MCP connector.
|
||||
|
||||
|
|
@ -2918,7 +2911,6 @@ async def delete_mcp_connector(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete an MCP connector.
|
||||
|
||||
|
|
@ -2977,7 +2969,6 @@ async def test_mcp_server_connection(
|
|||
server_config: dict = Body(...),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Test connection to an MCP server and fetch available tools.
|
||||
|
||||
|
|
@ -3058,7 +3049,6 @@ async def get_drive_picker_token(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Return an OAuth access token + client ID for the Google Picker API."""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)
|
||||
|
|
|
|||
|
|
@ -279,7 +279,9 @@ async def update_search_space(
|
|||
) from e
|
||||
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead)
|
||||
@router.put(
|
||||
"/searchspaces/{search_space_id}/api-access", response_model=SearchSpaceRead
|
||||
)
|
||||
async def update_search_space_api_access(
|
||||
search_space_id: int,
|
||||
body: SearchSpaceApiAccessUpdate,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.db import get_async_session
|
||||
from app.services.memory import (
|
||||
MemoryRead,
|
||||
MemoryScope,
|
||||
|
|
@ -32,7 +32,6 @@ async def get_team_memory(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
memory_md = await read_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
|
|
@ -49,7 +48,6 @@ async def update_team_memory(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
result = await save_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
|
|
@ -68,7 +66,6 @@ async def reset_team_memory(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
result = await reset_memory(
|
||||
scope=MemoryScope.TEAM,
|
||||
|
|
|
|||
34
surfsense_backend/app/routes/users_routes.py
Normal file
34
surfsense_backend/app/routes/users_routes.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Cookie-aware user profile routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.schemas import UserRead, UserUpdate
|
||||
from app.users import (
|
||||
UserManager,
|
||||
get_auth_context,
|
||||
get_user_manager,
|
||||
require_session_context,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
async def get_current_user_profile(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
return auth.user
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserRead)
|
||||
async def update_current_user_profile(
|
||||
update: UserUpdate,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_session_context),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
updated_user = await user_manager.update(
|
||||
update, auth.user, safe=True, request=request
|
||||
)
|
||||
return updated_user
|
||||
|
|
@ -21,7 +21,6 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
VideoPresentation,
|
||||
get_async_session,
|
||||
)
|
||||
|
|
@ -93,7 +92,6 @@ async def read_video_presentation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a specific video presentation by ID.
|
||||
Requires authentication with VIDEO_PRESENTATIONS_READ permission.
|
||||
|
|
@ -137,7 +135,6 @@ async def delete_video_presentation(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a video presentation.
|
||||
Requires VIDEO_PRESENTATIONS_DELETE permission for the search space.
|
||||
|
|
@ -181,7 +178,6 @@ async def stream_slide_audio(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Stream the audio file for a specific slide in a video presentation.
|
||||
The slide_number is 1-based. Audio path is read from the slides JSONB.
|
||||
|
|
|
|||
31
surfsense_backend/app/routes/zero_context_routes.py
Normal file
31
surfsense_backend/app/routes/zero_context_routes.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""Zero sync authentication context routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import get_allowed_read_space_ids
|
||||
|
||||
router = APIRouter(prefix="/zero", tags=["zero"])
|
||||
|
||||
|
||||
class ZeroContextResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
user_id: str = Field(alias="userId")
|
||||
allowed_space_ids: list[int] = Field(alias="allowedSpaceIds")
|
||||
|
||||
|
||||
@router.get("/context", response_model=ZeroContextResponse)
|
||||
async def get_zero_context(
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> ZeroContextResponse:
|
||||
allowed_space_ids = await get_allowed_read_space_ids(session, auth)
|
||||
return ZeroContextResponse(
|
||||
user_id=str(auth.user.id),
|
||||
allowed_space_ids=allowed_space_ids,
|
||||
)
|
||||
|
|
@ -242,9 +242,9 @@ __all__ = [
|
|||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorRead",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSpaceApiAccessUpdate",
|
||||
# Search space schemas
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceApiAccessUpdate",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
|
|
|
|||
|
|
@ -6,21 +6,22 @@ from pydantic import BaseModel
|
|||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request body for token refresh endpoint."""
|
||||
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response from token refresh endpoint."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
access_expires_at: int
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request body for logout endpoint (current device)."""
|
||||
|
||||
refresh_token: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
|
|
@ -33,3 +34,19 @@ class LogoutAllResponse(BaseModel):
|
|||
"""Response from logout all devices endpoint."""
|
||||
|
||||
detail: str = "Successfully logged out from all devices"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
authenticated: bool = True
|
||||
access_expires_at: int | None = None
|
||||
|
||||
|
||||
class DesktopSessionRequest(BaseModel):
|
||||
code: str
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class DesktopLoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
|
|
|||
|
|
@ -435,7 +435,6 @@ async def list_snapshots_for_thread(
|
|||
thread_id: int,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a thread."""
|
||||
from app.config import config
|
||||
|
||||
|
|
@ -482,7 +481,6 @@ async def list_snapshots_for_search_space(
|
|||
search_space_id: int,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a search space."""
|
||||
from app.config import config
|
||||
|
||||
|
|
@ -540,7 +538,6 @@ async def delete_snapshot(
|
|||
snapshot_id: int,
|
||||
auth: AuthContext,
|
||||
) -> bool:
|
||||
user = auth.user
|
||||
"""Delete a specific snapshot. Only thread owner can delete."""
|
||||
# Get snapshot with thread
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
"""Celery task for pruning expired refresh-token rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete, or_
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import RefreshToken, async_session_maker
|
||||
|
||||
|
||||
@celery_app.task(name="purge_refresh_tokens")
|
||||
def purge_refresh_tokens() -> int:
|
||||
return asyncio.run(_purge_refresh_tokens())
|
||||
|
||||
|
||||
async def _purge_refresh_tokens() -> int:
|
||||
now = datetime.now(UTC)
|
||||
revoked_cutoff = now - timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
delete(RefreshToken).where(
|
||||
or_(
|
||||
RefreshToken.expires_at < now,
|
||||
RefreshToken.revoked_at < revoked_cutoff,
|
||||
)
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount or 0
|
||||
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
|
|
@ -12,11 +13,12 @@ from fastapi_users.authentication import (
|
|||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from pydantic import BaseModel
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.auth.session_cookies import access_expires_at, write_session
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Prompt,
|
||||
|
|
@ -36,12 +38,6 @@ from app.utils.refresh_tokens import create_refresh_token
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
|
||||
|
|
@ -230,8 +226,23 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
|||
yield UserManager(user_db)
|
||||
|
||||
|
||||
class IatJWTStrategy(JWTStrategy[models.UP, models.ID]):
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"iat": int(datetime.now(UTC).timestamp()),
|
||||
}
|
||||
return generate_jwt(
|
||||
data,
|
||||
self.encode_key,
|
||||
self.lifetime_seconds,
|
||||
algorithm=self.algorithm,
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(
|
||||
return IatJWTStrategy(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
)
|
||||
|
|
@ -260,9 +271,6 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
|
|
@ -271,24 +279,26 @@ class CustomBearerTransport(BearerTransport):
|
|||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create refresh token: {e}")
|
||||
# Fall back to response without refresh token
|
||||
refresh_token = ""
|
||||
|
||||
bearer_response = BearerResponse(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create session",
|
||||
) from e
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
redirect_url = (
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||
f"?token={bearer_response.access_token}"
|
||||
f"&refresh_token={bearer_response.refresh_token}"
|
||||
response = RedirectResponse(
|
||||
f"{config.NEXT_FRONTEND_URL}/dashboard",
|
||||
status_code=302,
|
||||
)
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
return JSONResponse(bearer_response.model_dump())
|
||||
response = JSONResponse(
|
||||
{
|
||||
"authenticated": True,
|
||||
"access_expires_at": access_expires_at(token),
|
||||
}
|
||||
)
|
||||
|
||||
write_session(response, token, refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
|
@ -303,6 +313,22 @@ auth_backend = AuthenticationBackend(
|
|||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
||||
def _token_meets_epoch(token: str) -> bool:
|
||||
min_issued_at = config.MIN_ISSUED_AT
|
||||
if min_issued_at <= 0:
|
||||
return True
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
return False
|
||||
|
||||
issued_at = payload.get("iat")
|
||||
return isinstance(issued_at, (int, float)) and int(issued_at) >= min_issued_at
|
||||
|
||||
|
||||
async def get_auth_context(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -315,38 +341,42 @@ async def get_auth_context(
|
|||
receives the full SurfSense principal instead of a bare User.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
if auth_header:
|
||||
scheme, _, credential = auth_header.partition(" ")
|
||||
is_bearer = scheme.lower() == "bearer" and bool(credential)
|
||||
token = credential if is_bearer else auth_header.strip()
|
||||
|
||||
scheme, _, token = auth_header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
|
||||
if token.startswith(PAT_PREFIX):
|
||||
pat = await resolve_pat(session, token)
|
||||
if pat and pat.user and pat.user.is_active:
|
||||
maybe_touch_last_used(pat)
|
||||
return AuthContext.pat_auth(pat.user, pat)
|
||||
if is_bearer and _token_meets_epoch(token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read bearer access token")
|
||||
user = None
|
||||
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read access token")
|
||||
user = None
|
||||
if user and user.is_active:
|
||||
return AuthContext.session(user)
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
cookie_token = request.cookies.get(config.SESSION_COOKIE_NAME)
|
||||
if cookie_token and _token_meets_epoch(cookie_token):
|
||||
try:
|
||||
user = await get_jwt_strategy().read_token(cookie_token, user_manager)
|
||||
except Exception:
|
||||
logger.exception("Failed to read session cookie access token")
|
||||
user = None
|
||||
|
||||
return AuthContext.session(user)
|
||||
if user and user.is_active:
|
||||
return AuthContext.session(user)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
)
|
||||
|
||||
|
||||
async def allow_any_principal(
|
||||
|
|
@ -371,6 +401,3 @@ async def require_session_context(
|
|||
detail="This action requires an interactive session",
|
||||
)
|
||||
return auth
|
||||
|
||||
|
||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,15 @@ logger = logging.getLogger(__name__)
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
|
||||
def _render_inline_content(
|
||||
content: list[dict[str, Any]] | None,
|
||||
inherited_styles: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Convert BlockNote inline content array to a markdown string."""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
inherited_styles = inherited_styles or {}
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
|
|
@ -37,7 +41,10 @@ def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
|
|||
|
||||
if item_type == "text":
|
||||
text = item.get("text", "")
|
||||
styles: dict[str, Any] = item.get("styles", {})
|
||||
styles: dict[str, Any] = {
|
||||
**inherited_styles,
|
||||
**item.get("styles", {}),
|
||||
}
|
||||
|
||||
# Apply inline styles (order: code first so nested marks don't break it)
|
||||
if styles.get("code"):
|
||||
|
|
@ -56,7 +63,11 @@ def _render_inline_content(content: list[dict[str, Any]] | None) -> str:
|
|||
elif item_type == "link":
|
||||
href = item.get("href", "")
|
||||
link_content = item.get("content", [])
|
||||
link_text = _render_inline_content(link_content) if link_content else href
|
||||
link_text = (
|
||||
_render_inline_content(link_content, inherited_styles)
|
||||
if link_content
|
||||
else href
|
||||
)
|
||||
parts.append(f"[{link_text}]({href})")
|
||||
|
||||
else:
|
||||
|
|
@ -89,6 +100,7 @@ def _render_block(
|
|||
"""
|
||||
block_type = block.get("type", "paragraph")
|
||||
props: dict[str, Any] = block.get("props", {})
|
||||
styles: dict[str, Any] = block.get("styles", {})
|
||||
content = block.get("content")
|
||||
children: list[dict[str, Any]] = block.get("children", [])
|
||||
prefix = " " * indent # 2-space indent per nesting level
|
||||
|
|
@ -98,17 +110,17 @@ def _render_block(
|
|||
# --- Block type handlers ---
|
||||
|
||||
if block_type == "paragraph":
|
||||
text = _render_inline_content(content) if content else ""
|
||||
text = _render_inline_content(content, styles) if content else ""
|
||||
lines.append(f"{prefix}{text}")
|
||||
|
||||
elif block_type == "heading":
|
||||
level = props.get("level", 1)
|
||||
hashes = "#" * min(max(level, 1), 6)
|
||||
text = _render_inline_content(content) if content else ""
|
||||
text = _render_inline_content(content, styles) if content else ""
|
||||
lines.append(f"{prefix}{hashes} {text}")
|
||||
|
||||
elif block_type == "bulletListItem":
|
||||
text = _render_inline_content(content) if content else ""
|
||||
text = _render_inline_content(content, styles) if content else ""
|
||||
lines.append(f"{prefix}- {text}")
|
||||
|
||||
elif block_type == "numberedListItem":
|
||||
|
|
@ -118,13 +130,13 @@ def _render_block(
|
|||
numbered_list_counter = int(start)
|
||||
else:
|
||||
numbered_list_counter += 1
|
||||
text = _render_inline_content(content) if content else ""
|
||||
text = _render_inline_content(content, styles) if content else ""
|
||||
lines.append(f"{prefix}{numbered_list_counter}. {text}")
|
||||
|
||||
elif block_type == "checkListItem":
|
||||
checked = props.get("checked", False)
|
||||
marker = "[x]" if checked else "[ ]"
|
||||
text = _render_inline_content(content) if content else ""
|
||||
text = _render_inline_content(content, styles) if content else ""
|
||||
lines.append(f"{prefix}- {marker} {text}")
|
||||
|
||||
elif block_type == "codeBlock":
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
PAT_PREFIX = "ss_pat_"
|
||||
PAT_TOKEN_BYTES = 32
|
||||
LAST_USED_THROTTLE = timedelta(minutes=10)
|
||||
_last_used_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
|
||||
def generate_pat() -> str:
|
||||
|
|
@ -70,4 +71,6 @@ def maybe_touch_last_used(pat: PersonalAccessToken) -> None:
|
|||
if last_used_at is not None and now - last_used_at < LAST_USED_THROTTLE:
|
||||
return
|
||||
|
||||
asyncio.create_task(_touch_last_used(pat.id))
|
||||
task = asyncio.create_task(_touch_last_used(pat.id))
|
||||
_last_used_tasks.add(task)
|
||||
task.add_done_callback(_last_used_tasks.discard)
|
||||
|
|
|
|||
|
|
@ -80,6 +80,28 @@ async def get_user_permissions(
|
|||
return []
|
||||
|
||||
|
||||
async def get_allowed_read_space_ids(
|
||||
session: AsyncSession,
|
||||
auth: AuthContext,
|
||||
) -> list[int]:
|
||||
"""Return search spaces the principal may read through sync transports.
|
||||
|
||||
This mirrors the basic REST search-space access rule: membership is required,
|
||||
and PAT principals are additionally constrained by the per-space API gate.
|
||||
"""
|
||||
stmt = (
|
||||
select(SearchSpaceMembership.search_space_id)
|
||||
.join(SearchSpace, SearchSpace.id == SearchSpaceMembership.search_space_id)
|
||||
.filter(SearchSpaceMembership.user_id == auth.user.id)
|
||||
.order_by(SearchSpaceMembership.search_space_id)
|
||||
)
|
||||
if auth.is_gated:
|
||||
stmt = stmt.filter(SearchSpace.api_access_enabled == True) # noqa: E712
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def _enforce_api_access_gate(
|
||||
session: AsyncSession,
|
||||
auth: AuthContext,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import hashlib
|
|||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
|
|
@ -14,6 +15,13 @@ from app.db import RefreshToken, async_session_maker
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RefreshRotationResult:
|
||||
user_id: uuid.UUID
|
||||
refresh_token: str | None
|
||||
access_only: bool = False
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a cryptographically secure refresh token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
|
@ -27,6 +35,7 @@ def hash_token(token: str) -> str:
|
|||
async def create_refresh_token(
|
||||
user_id: uuid.UUID,
|
||||
family_id: uuid.UUID | None = None,
|
||||
absolute_expiry: datetime | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create and store a new refresh token for a user.
|
||||
|
|
@ -40,8 +49,14 @@ async def create_refresh_token(
|
|||
"""
|
||||
token = generate_refresh_token()
|
||||
token_hash = hash_token(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||||
now = datetime.now(UTC)
|
||||
if absolute_expiry is None:
|
||||
absolute_expiry = now + timedelta(
|
||||
seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS
|
||||
)
|
||||
expires_at = min(
|
||||
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
|
||||
absolute_expiry,
|
||||
)
|
||||
|
||||
if family_id is None:
|
||||
|
|
@ -53,6 +68,7 @@ async def create_refresh_token(
|
|||
token_hash=token_hash,
|
||||
expires_at=expires_at,
|
||||
family_id=family_id,
|
||||
absolute_expiry=absolute_expiry,
|
||||
)
|
||||
session.add(refresh_token)
|
||||
await session.commit()
|
||||
|
|
@ -61,15 +77,7 @@ async def create_refresh_token(
|
|||
|
||||
|
||||
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||||
"""
|
||||
Validate a refresh token. Handles reuse detection.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token
|
||||
|
||||
Returns:
|
||||
RefreshToken if valid, None otherwise
|
||||
"""
|
||||
"""Validate an active refresh token without rotating it."""
|
||||
token_hash = hash_token(token)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
|
|
@ -81,43 +89,87 @@ async def validate_refresh_token(token: str) -> RefreshToken | None:
|
|||
if not refresh_token:
|
||||
return None
|
||||
|
||||
# Reuse detection: revoked token used while family has active tokens
|
||||
if refresh_token.is_revoked:
|
||||
active = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.family_id == refresh_token.family_id,
|
||||
RefreshToken.is_revoked == False, # noqa: E712
|
||||
RefreshToken.expires_at > datetime.now(UTC),
|
||||
)
|
||||
now = datetime.now(UTC)
|
||||
if (
|
||||
refresh_token.revoked_at is not None
|
||||
or now >= refresh_token.expires_at
|
||||
or (
|
||||
refresh_token.absolute_expiry is not None
|
||||
and now >= refresh_token.absolute_expiry
|
||||
)
|
||||
if active.scalars().first():
|
||||
# Revoke entire family
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||||
return None
|
||||
|
||||
if refresh_token.is_expired:
|
||||
):
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
|
||||
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||||
"""Revoke old token and create new one in same family."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.id == old_token.id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
async def rotate_refresh_token(token: str) -> RefreshRotationResult | None:
|
||||
"""Atomically rotate a refresh token with access-only grace."""
|
||||
token_hash = hash_token(token)
|
||||
now = datetime.now(UTC)
|
||||
grace_window = timedelta(seconds=config.REFRESH_ROTATION_GRACE_SECONDS)
|
||||
|
||||
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||||
async with async_session_maker() as session:
|
||||
async with session.begin():
|
||||
result = await session.execute(
|
||||
select(RefreshToken)
|
||||
.where(RefreshToken.token_hash == token_hash)
|
||||
.with_for_update()
|
||||
)
|
||||
refresh_token = result.scalars().first()
|
||||
|
||||
if not refresh_token:
|
||||
return None
|
||||
user_id = refresh_token.user_id
|
||||
|
||||
if refresh_token.revoked_at is not None:
|
||||
if (
|
||||
now - refresh_token.revoked_at <= grace_window
|
||||
and now < refresh_token.expires_at
|
||||
):
|
||||
return RefreshRotationResult(
|
||||
user_id=user_id,
|
||||
refresh_token=None,
|
||||
access_only=True,
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
logger.warning(f"Token reuse detected for user {user_id}")
|
||||
return None
|
||||
|
||||
if now >= refresh_token.expires_at:
|
||||
return None
|
||||
|
||||
family_cap = refresh_token.absolute_expiry or (
|
||||
now + timedelta(seconds=config.REFRESH_ABSOLUTE_LIFETIME_SECONDS)
|
||||
)
|
||||
if now >= family_cap:
|
||||
return None
|
||||
|
||||
new_plaintext = generate_refresh_token()
|
||||
child = RefreshToken(
|
||||
user_id=user_id,
|
||||
token_hash=hash_token(new_plaintext),
|
||||
expires_at=min(
|
||||
now + timedelta(seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS),
|
||||
family_cap,
|
||||
),
|
||||
family_id=refresh_token.family_id,
|
||||
absolute_expiry=family_cap,
|
||||
)
|
||||
session.add(child)
|
||||
refresh_token.revoked_at = now
|
||||
refresh_token.absolute_expiry = family_cap
|
||||
|
||||
return RefreshRotationResult(
|
||||
user_id=user_id,
|
||||
refresh_token=new_plaintext,
|
||||
access_only=False,
|
||||
)
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> bool:
|
||||
|
|
@ -131,12 +183,13 @@ async def revoke_refresh_token(token: str) -> bool:
|
|||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
token_hash = hash_token(token)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.token_hash == token_hash)
|
||||
.values(is_revoked=True)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
|
@ -144,10 +197,11 @@ async def revoke_refresh_token(token: str) -> bool:
|
|||
|
||||
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||||
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||||
now = datetime.now(UTC)
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.user_id == user_id)
|
||||
.values(is_revoked=True)
|
||||
.values(revoked_at=now, expires_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -52,6 +52,16 @@ AUTOMATION_RUN_COLS = [
|
|||
"created_at",
|
||||
]
|
||||
|
||||
AUTOMATION_COLS = [
|
||||
"id",
|
||||
"search_space_id",
|
||||
]
|
||||
|
||||
NEW_CHAT_THREAD_COLS = [
|
||||
"id",
|
||||
"search_space_id",
|
||||
]
|
||||
|
||||
# Enough to drive the lifecycle UI by push: status, the reviewable brief, and
|
||||
# its version. The bulky source_content and transcript are deliberately excluded
|
||||
# and fetched over REST when a gate opens.
|
||||
|
|
@ -73,10 +83,12 @@ ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = {
|
|||
"documents": DOCUMENT_COLS,
|
||||
"folders": None,
|
||||
"search_source_connectors": None,
|
||||
"new_chat_threads": NEW_CHAT_THREAD_COLS,
|
||||
"new_chat_messages": None,
|
||||
"chat_comments": None,
|
||||
"chat_session_state": None,
|
||||
"user": USER_COLS,
|
||||
"automations": AUTOMATION_COLS,
|
||||
"automation_runs": AUTOMATION_RUN_COLS,
|
||||
"podcasts": PODCAST_COLS,
|
||||
}
|
||||
|
|
|
|||
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
69
surfsense_backend/scripts/revoke_refresh_tokens_cutover.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""One-shot cutover helper to revoke every refresh token.
|
||||
|
||||
Run with --yes during the auth-hardening cutover, alongside setting
|
||||
MIN_ISSUED_AT to the deploy epoch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
|
||||
async def _count_active_tokens() -> int:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT count(*)
|
||||
FROM refresh_tokens
|
||||
WHERE revoked_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
"""
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
async def _revoke_all_tokens() -> int:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = NOW(),
|
||||
expires_at = NOW()
|
||||
WHERE revoked_at IS NULL
|
||||
OR expires_at > NOW()
|
||||
"""
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
action="store_true",
|
||||
help="Actually revoke tokens. Without this flag the command is a dry run.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
active_count = await _count_active_tokens()
|
||||
if not args.yes:
|
||||
print(f"Dry run: {active_count} active refresh token(s) would be revoked.")
|
||||
print("Re-run with --yes during the auth-hardening cutover to revoke them.")
|
||||
return
|
||||
|
||||
updated_count = await _revoke_all_tokens()
|
||||
print(f"Revoked {updated_count} refresh token row(s).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -8,7 +8,6 @@ webhook fulfillment (idempotent), and the reconciliation fallback.
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import asyncpg
|
||||
import httpx
|
||||
|
|
@ -63,18 +62,13 @@ def _extract_access_token(response: httpx.Response) -> str | None:
|
|||
if response.status_code == 200:
|
||||
return response.json()["access_token"]
|
||||
|
||||
if response.status_code == 302:
|
||||
location = response.headers.get("location", "")
|
||||
return parse_qs(urlparse(location).query).get("token", [None])[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _authenticate_test_user(client: httpx.AsyncClient) -> str:
|
||||
response = await client.post(
|
||||
"/auth/jwt/login",
|
||||
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
"/auth/desktop/login",
|
||||
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
)
|
||||
token = _extract_access_token(response)
|
||||
if token:
|
||||
|
|
@ -89,9 +83,8 @@ async def _authenticate_test_user(client: httpx.AsyncClient) -> str:
|
|||
)
|
||||
|
||||
response = await client.post(
|
||||
"/auth/jwt/login",
|
||||
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
"/auth/desktop/login",
|
||||
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
)
|
||||
token = _extract_access_token(response)
|
||||
assert token, f"Login failed ({response.status_code}): {response.text}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from app.auth.session_cookies import TransportMode, issue, read_refresh
|
||||
from app.config import config
|
||||
|
||||
|
||||
def _request_with_refresh_cookie(token: str) -> Request:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"path": "/auth/jwt/refresh",
|
||||
"headers": [(b"cookie", f"{config.REFRESH_COOKIE_NAME}={token}".encode())],
|
||||
"scheme": "https",
|
||||
"server": ("testserver", 443),
|
||||
}
|
||||
return Request(scope)
|
||||
|
||||
|
||||
def test_cookie_transport_sets_cookies_without_body_tokens():
|
||||
response = Response()
|
||||
|
||||
body = issue(
|
||||
response,
|
||||
TransportMode.COOKIE,
|
||||
access="access-token",
|
||||
refresh="refresh-token",
|
||||
access_expires_at=123,
|
||||
)
|
||||
|
||||
assert "access_token" not in body
|
||||
assert "refresh_token" not in body
|
||||
assert body == {"authenticated": True, "access_expires_at": 123}
|
||||
|
||||
set_cookie_headers = response.headers.getlist("set-cookie")
|
||||
assert any(config.SESSION_COOKIE_NAME in header for header in set_cookie_headers)
|
||||
assert any(config.REFRESH_COOKIE_NAME in header for header in set_cookie_headers)
|
||||
|
||||
|
||||
def test_cookie_transport_re_stamps_access_without_refresh_body_or_cookie():
|
||||
response = Response()
|
||||
|
||||
body = issue(
|
||||
response,
|
||||
TransportMode.COOKIE,
|
||||
access="access-token",
|
||||
refresh=None,
|
||||
access_expires_at=123,
|
||||
)
|
||||
|
||||
assert "access_token" not in body
|
||||
assert "refresh_token" not in body
|
||||
|
||||
set_cookie_headers = response.headers.getlist("set-cookie")
|
||||
assert any(config.SESSION_COOKIE_NAME in header for header in set_cookie_headers)
|
||||
assert not any(
|
||||
config.REFRESH_COOKIE_NAME in header for header in set_cookie_headers
|
||||
)
|
||||
|
||||
|
||||
def test_header_transport_returns_body_tokens_without_cookies():
|
||||
response = Response()
|
||||
|
||||
body = issue(
|
||||
response,
|
||||
TransportMode.HEADER,
|
||||
access="access-token",
|
||||
refresh="refresh-token",
|
||||
access_expires_at=123,
|
||||
)
|
||||
|
||||
assert body == {
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"token_type": "bearer",
|
||||
"access_expires_at": 123,
|
||||
}
|
||||
assert "set-cookie" not in response.headers
|
||||
|
||||
|
||||
def test_read_refresh_cookie_source_wins_over_body_source():
|
||||
request = _request_with_refresh_cookie("cookie-token")
|
||||
|
||||
refresh, mode = read_refresh(request, SimpleNamespace(refresh_token="body-token"))
|
||||
|
||||
assert refresh == "cookie-token"
|
||||
assert mode is TransportMode.COOKIE
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Regression tests for Zero's backend-computed authorization context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import PersonalAccessToken, SearchSpace, User
|
||||
from app.routes.search_spaces_routes import create_default_roles_and_membership
|
||||
from app.utils.rbac import check_search_space_access, get_allowed_read_space_ids
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _pat_auth(user: User) -> AuthContext:
|
||||
pat = PersonalAccessToken(
|
||||
user_id=user.id,
|
||||
user=user,
|
||||
token_hash="1" * 64,
|
||||
token_prefix="ss_pat_zero",
|
||||
label="Zero PAT",
|
||||
)
|
||||
return AuthContext.pat_auth(user, pat)
|
||||
|
||||
|
||||
async def _space_with_membership(
|
||||
db_session: AsyncSession,
|
||||
user: User,
|
||||
*,
|
||||
api_access_enabled: bool,
|
||||
) -> SearchSpace:
|
||||
space = SearchSpace(
|
||||
name="Zero Authz Space",
|
||||
user_id=user.id,
|
||||
api_access_enabled=api_access_enabled,
|
||||
)
|
||||
db_session.add(space)
|
||||
await db_session.flush()
|
||||
await create_default_roles_and_membership(db_session, space.id, user.id)
|
||||
await db_session.flush()
|
||||
return space
|
||||
|
||||
|
||||
async def test_zero_read_set_matches_session_search_space_access(
|
||||
db_session: AsyncSession,
|
||||
db_user: User,
|
||||
db_search_space: SearchSpace,
|
||||
):
|
||||
disabled_space = await _space_with_membership(
|
||||
db_session,
|
||||
db_user,
|
||||
api_access_enabled=False,
|
||||
)
|
||||
session_auth = AuthContext.session(db_user)
|
||||
|
||||
allowed_ids = set(await get_allowed_read_space_ids(db_session, session_auth))
|
||||
|
||||
for space in (db_search_space, disabled_space):
|
||||
membership = await check_search_space_access(db_session, session_auth, space.id)
|
||||
assert membership.search_space_id in allowed_ids
|
||||
|
||||
|
||||
async def test_zero_read_set_applies_pat_api_access_gate(
|
||||
db_session: AsyncSession,
|
||||
db_user: User,
|
||||
db_search_space: SearchSpace,
|
||||
):
|
||||
db_search_space.api_access_enabled = True
|
||||
disabled_space = await _space_with_membership(
|
||||
db_session,
|
||||
db_user,
|
||||
api_access_enabled=False,
|
||||
)
|
||||
await db_session.flush()
|
||||
pat_auth = _pat_auth(db_user)
|
||||
|
||||
allowed_ids = set(await get_allowed_read_space_ids(db_session, pat_auth))
|
||||
|
||||
assert db_search_space.id in allowed_ids
|
||||
assert disabled_space.id not in allowed_ids
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_search_space_access(db_session, pat_auth, disabled_space.id)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
|
@ -40,7 +40,9 @@ async def cleanup_supervisors():
|
|||
async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook")
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
|
||||
)
|
||||
|
||||
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||
|
||||
|
|
@ -53,7 +55,9 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc
|
|||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
|
||||
)
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult([])
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -75,7 +79,9 @@ async def test_start_byo_long_poll_spawns_one_supervisor_per_account(
|
|||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
|
||||
)
|
||||
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult(accounts)
|
||||
|
|
@ -125,7 +131,9 @@ async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
|
|||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
|
||||
)
|
||||
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
|
||||
monkeypatch.setattr(
|
||||
byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
|
||||
)
|
||||
session = mocker.AsyncMock()
|
||||
session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -450,7 +450,9 @@ class TestRevertTurnDispatch:
|
|||
thread_id=1,
|
||||
chat_turn_id="ct-mixed-all",
|
||||
session=session,
|
||||
auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id
|
||||
auth=AuthContext.session(
|
||||
_FakeUser()
|
||||
), # only id=7 has a different user_id
|
||||
)
|
||||
|
||||
assert response.total == len(rows) == 6
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch):
|
|||
|
||||
_CapturedChatLiteLLM.calls = []
|
||||
|
||||
async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace:
|
||||
async def _fake_search_space(
|
||||
_session: Any, _search_space_id: int
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(id=42, user_id="user-1")
|
||||
|
||||
monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space)
|
||||
|
|
|
|||
|
|
@ -32,11 +32,7 @@ CONNECTOR_LISTERS = [
|
|||
|
||||
|
||||
def _python_files() -> list[Path]:
|
||||
return [
|
||||
path
|
||||
for path in APP_ROOT.rglob("*.py")
|
||||
if "__pycache__" not in path.parts
|
||||
]
|
||||
return [path for path in APP_ROOT.rglob("*.py") if "__pycache__" not in path.parts]
|
||||
|
||||
|
||||
def test_current_active_user_is_removed_from_app_tree() -> None:
|
||||
|
|
|
|||
22
surfsense_backend/tests/unit/test_zero_authz_static.py
Normal file
22
surfsense_backend/tests/unit/test_zero_authz_static.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Static guards for Zero authorization wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
WEB_ROOT = REPO_ROOT / "surfsense_web"
|
||||
|
||||
|
||||
def test_zero_query_route_uses_authoritative_backend_context() -> None:
|
||||
route = WEB_ROOT / "app/api/zero/query/route.ts"
|
||||
text = route.read_text()
|
||||
|
||||
assert "/zero/context" in text
|
||||
assert "/users/me" not in text
|
||||
assert "userID: auth.ctx.userId" in text
|
||||
assert "handleQueryRequest({" in text
|
||||
|
|
@ -290,4 +290,4 @@ class TestExtractTextContent:
|
|||
def test_boolean_returns_empty_string(self):
|
||||
from app.utils.content_utils import extract_text_content
|
||||
|
||||
assert extract_text_content(True) == ""
|
||||
assert extract_text_content(True) == ""
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ TEST_PASSWORD = "testpassword123"
|
|||
async def get_auth_token(client: httpx.AsyncClient) -> str:
|
||||
"""Log in and return a Bearer JWT token, registering the user first if needed."""
|
||||
response = await client.post(
|
||||
"/auth/jwt/login",
|
||||
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
"/auth/desktop/login",
|
||||
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()["access_token"]
|
||||
|
|
@ -32,9 +31,8 @@ async def get_auth_token(client: httpx.AsyncClient) -> str:
|
|||
)
|
||||
|
||||
response = await client.post(
|
||||
"/auth/jwt/login",
|
||||
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
"/auth/desktop/login",
|
||||
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
|
||||
)
|
||||
assert response.status_code == 200, (
|
||||
f"Login after registration failed ({response.status_code}): {response.text}"
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
# Electron-specific build-time configuration.
|
||||
# Set before running pnpm dist:mac / dist:win / dist:linux.
|
||||
|
||||
# The hosted web frontend URL. Used to intercept OAuth redirects and keep them
|
||||
# inside the desktop app. Set to your production frontend domain.
|
||||
HOSTED_FRONTEND_URL=https://surfsense.com
|
||||
|
||||
# PostHog analytics (leave empty to disable)
|
||||
POSTHOG_KEY=
|
||||
POSTHOG_HOST=https://assets.surfsense.com
|
||||
|
|
@ -3,7 +3,15 @@
|
|||
|
||||
# The hosted web frontend URL. Used to intercept OAuth redirects and keep them
|
||||
# inside the desktop app. Set to your production frontend domain.
|
||||
HOSTED_FRONTEND_URL=https://surfsense.com
|
||||
HOSTED_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# The backend API URL used by desktop auth and refresh flows.
|
||||
HOSTED_BACKEND_URL=http://localhost:8000
|
||||
|
||||
# Public Google OAuth Desktop app client ID. Required for packaged desktop
|
||||
# Google login using loopback + PKCE. This is safe to ship in the desktop app;
|
||||
# the PKCE code verifier, not a client secret, protects the token exchange.
|
||||
GOOGLE_DESKTOP_CLIENT_ID=your_google_desktop_client_id.apps.googleusercontent.com
|
||||
|
||||
# Runtime override for the above (read at app start, no rebuild required).
|
||||
# Useful for self-hosters whose backend NEXT_FRONTEND_URL differs from the
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
"@types/node": "^25.5.0",
|
||||
"concurrently": "^9.2.1",
|
||||
"dotenv": "^17.3.1",
|
||||
"electron": "^41.0.2",
|
||||
"electron": "^42.4.0",
|
||||
"electron-builder": "^26.8.1",
|
||||
"esbuild": "^0.27.4",
|
||||
"typescript": "^5.9.3",
|
||||
|
|
|
|||
143
surfsense_desktop/pnpm-lock.yaml
generated
143
surfsense_desktop/pnpm-lock.yaml
generated
|
|
@ -46,8 +46,8 @@ importers:
|
|||
specifier: ^17.3.1
|
||||
version: 17.3.1
|
||||
electron:
|
||||
specifier: ^41.0.2
|
||||
version: 41.0.2
|
||||
specifier: ^42.4.0
|
||||
version: 42.4.0
|
||||
electron-builder:
|
||||
specifier: ^26.8.1
|
||||
version: 26.8.1(electron-builder-squirrel-windows@26.8.1)
|
||||
|
|
@ -70,6 +70,10 @@ packages:
|
|||
resolution: {integrity: sha512-0cp4PsWQ/9avqTVMCtZ+GirikIA36ikvjtHweU4/j8yLtgObI0+JUPhYFScgwlteveGB1rt3Cm8UhN04XayDig==}
|
||||
engines: {node: '>= 8.9.0'}
|
||||
|
||||
'@electron-internal/extract-zip@1.0.3':
|
||||
resolution: {integrity: sha512-OjKpjB7gohtEjZiq6nDx1egqjZJhGPN1iFOIED+NFhB/MMkXw/XRcHjh1DGXKT5z2W9eW7Jy2UKU3gpjvusFTQ==}
|
||||
engines: {node: '>=22.12.0'}
|
||||
|
||||
'@electron/asar@3.4.1':
|
||||
resolution: {integrity: sha512-i4/rNPRS84t0vSRa2HorerGRXWyF4vThfHesw0dmcWHp+cspK743UanA0suA5Q5y8kzY2y6YKrvbIUn69BCAiA==}
|
||||
engines: {node: '>=10.12.0'}
|
||||
|
|
@ -79,14 +83,14 @@ packages:
|
|||
resolution: {integrity: sha512-zx0EIq78WlY/lBb1uXlziZmDZI4ubcCXIMJ4uGjXzZW0nS19TjSPeXPAjzzTmKQlJUZm0SbmZhPKP7tuQ1SsEw==}
|
||||
hasBin: true
|
||||
|
||||
'@electron/get@2.0.3':
|
||||
resolution: {integrity: sha512-Qkzpg2s9GnVV2I2BjRksUi43U5e6+zaQMcjoJy0C+C5oxaKl+fmckGDQFtRpZpZV0NQekuZZ+tGz7EA9TVnQtQ==}
|
||||
engines: {node: '>=12'}
|
||||
|
||||
'@electron/get@3.1.0':
|
||||
resolution: {integrity: sha512-F+nKc0xW+kVbBRhFzaMgPy3KwmuNTYX1fx6+FxxoSnNgwYX6LD7AKBTWkU0MQ6IBoe7dz069CNkR673sPAgkCQ==}
|
||||
engines: {node: '>=14'}
|
||||
|
||||
'@electron/get@5.0.0':
|
||||
resolution: {integrity: sha512-pjoBpru1KdEtcExBnuHAP1cAc/5faoedw0hzJkL3o4/IJp7HNF1+fbrdxT3gMYRX2oJfvnA/WXeCTVQpYYxyJA==}
|
||||
engines: {node: '>=22.12.0'}
|
||||
|
||||
'@electron/notarize@2.5.0':
|
||||
resolution: {integrity: sha512-jNT8nwH1f9X5GEITXaQ8IF/KdskvIkOFfB2CvwumsveVidzpSc+mvhhTMdAGSYF3O+Nq49lJ7y+ssODRXu06+A==}
|
||||
engines: {node: '>= 10.0.0'}
|
||||
|
|
@ -346,8 +350,8 @@ packages:
|
|||
'@types/ms@2.1.0':
|
||||
resolution: {integrity: sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==}
|
||||
|
||||
'@types/node@24.12.0':
|
||||
resolution: {integrity: sha512-GYDxsZi3ChgmckRT9HPU0WEhKLP08ev/Yfcq2AstjrDASOYCSXeyjDsHg4v5t4jOj7cyDX3vmprafKlWIG9MXQ==}
|
||||
'@types/node@24.13.2':
|
||||
resolution: {integrity: sha512-fRa09kZTgu8o71KFcDjUFuc7F+dEbZYZmkI0mg5YBTRs0yMKjYHsq/c0urDKeDb+D5qVgXOdFcuu+DZPKOITwA==}
|
||||
|
||||
'@types/node@25.5.0':
|
||||
resolution: {integrity: sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==}
|
||||
|
|
@ -361,9 +365,6 @@ packages:
|
|||
'@types/verror@1.10.11':
|
||||
resolution: {integrity: sha512-RlDm9K7+o5stv0Co8i8ZRGxDbrTxhJtgjqjFyVh/tXQyl/rYtTKlnTvZ88oSTeYREWurwx20Js4kTuKCsFkUtg==}
|
||||
|
||||
'@types/yauzl@2.10.3':
|
||||
resolution: {integrity: sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==}
|
||||
|
||||
'@xmldom/xmldom@0.8.11':
|
||||
resolution: {integrity: sha512-cQzWCtO6C8TQiYl1ruKNn2U6Ao4o4WBBcbL61yJl84x+j5sOWWFU9X7DpND8XZG3daDppSsigMdfAIl2upQBRw==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
|
|
@ -483,9 +484,6 @@ packages:
|
|||
resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==}
|
||||
engines: {node: 18 || 20 || >=22}
|
||||
|
||||
buffer-crc32@0.2.13:
|
||||
resolution: {integrity: sha512-VO9Ht/+p3SN7SKWqcrgEzjGbRSJYTx+Q1pTQC0wrWqHx0vpJraQ6GtHx8tvcg1rlK1byhU5gccxgOgj7B0TDkQ==}
|
||||
|
||||
buffer-from@1.1.2:
|
||||
resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==}
|
||||
|
||||
|
|
@ -714,9 +712,9 @@ packages:
|
|||
resolution: {integrity: sha512-bO3y10YikuUwUuDUQRM4KfwNkKhnpVO7IPdbsrejwN9/AABJzzTQ4GeHwyzNSrVO+tEH3/Np255a3sVZpZDjvg==}
|
||||
engines: {node: '>=8.0.0'}
|
||||
|
||||
electron@41.0.2:
|
||||
resolution: {integrity: sha512-raotm/aO8kOs1jD8SI8ssJ7EKciQOY295AOOprl1TxW7B0At8m5Ae7qNU1xdMxofiHMR8cNEGi9PKD3U+yT/mA==}
|
||||
engines: {node: '>= 12.20.55'}
|
||||
electron@42.4.0:
|
||||
resolution: {integrity: sha512-OXXqh9LD9KxXPv2Fe25EfU9N9AvWTuV6V81sfhQaNvTAXCd9ONA+Q4OWvMe+CmYD6xIwjFxGGtG/ZphDYYC5OQ==}
|
||||
engines: {node: '>= 22.12.0'}
|
||||
hasBin: true
|
||||
|
||||
emoji-regex@8.0.0:
|
||||
|
|
@ -777,11 +775,6 @@ packages:
|
|||
exponential-backoff@3.1.3:
|
||||
resolution: {integrity: sha512-ZgEeZXj30q+I0EN+CbSSpIyPaJ5HVQD18Z1m+u1FXbAeT94mr1zw50q4q6jiiC447Nl/YTcIYSAftiGqetwXCA==}
|
||||
|
||||
extract-zip@2.0.1:
|
||||
resolution: {integrity: sha512-GDhU9ntwuKyGXdZBUgTIe+vXnWj0fppUEtMDL0+idd5Sta8TGpHssn/eusA9mrPr9qNDym6SxAYZjNvCn/9RBg==}
|
||||
engines: {node: '>= 10.17.0'}
|
||||
hasBin: true
|
||||
|
||||
extsprintf@1.4.1:
|
||||
resolution: {integrity: sha512-Wrk35e8ydCKDj/ArClo1VrPVmN8zph5V4AtHwIuHhvMXsKf73UT3BOD+azBIW+3wOJ4FhEH7zyaJCFvChjYvMA==}
|
||||
engines: {'0': node >=0.6.0}
|
||||
|
|
@ -795,9 +788,6 @@ packages:
|
|||
fast-uri@3.1.0:
|
||||
resolution: {integrity: sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==}
|
||||
|
||||
fd-slicer@1.1.0:
|
||||
resolution: {integrity: sha512-cE1qsB/VwyQozZ+q1dGxR8LBYNZeofhEdUNGSMbQD3Gw2lAzX9Zb3uIU6Ebc/Fmyjo9AWWfnn0AUCHqtevs/8g==}
|
||||
|
||||
fdir@6.5.0:
|
||||
resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==}
|
||||
engines: {node: '>=12.0.0'}
|
||||
|
|
@ -838,6 +828,10 @@ packages:
|
|||
resolution: {integrity: sha512-CTXd6rk/M3/ULNQj8FBqBWHYBVYybQ3VPBw0xGKFe3tuH7ytT6ACnvzpIQ3UZtB8yvUKC2cXn1a+x+5EVQLovA==}
|
||||
engines: {node: '>=14.14'}
|
||||
|
||||
fs-extra@11.3.5:
|
||||
resolution: {integrity: sha512-eKpRKAovdpZtR1WopLHxlBWvAgPny3c4gX1G5Jhwmmw4XJj0ifSD5qB5TOo8hmA0wlRKDAOAhEE1yVPgs6Fgcg==}
|
||||
engines: {node: '>=14.14'}
|
||||
|
||||
fs-extra@7.0.1:
|
||||
resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==}
|
||||
engines: {node: '>=6 <7 || >=8'}
|
||||
|
|
@ -1045,6 +1039,9 @@ packages:
|
|||
jsonfile@6.2.0:
|
||||
resolution: {integrity: sha512-FGuPw30AdOIUTRMC2OMRtQV+jkVj2cfPqSeWXv1NEAJ1qZ5zb1X6z1mFhbfOB/iy3ssJCD+3KuZ8r8C3uVFlAg==}
|
||||
|
||||
jsonfile@6.2.1:
|
||||
resolution: {integrity: sha512-zwOTdL3rFQ/lRdBnntKVOX6k5cKJwEc1HdilT71BWEu7J41gXIB2MRp+vxduPSwZJPWBxEzv4yH1wYLJGUHX4Q==}
|
||||
|
||||
keyv@4.5.4:
|
||||
resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==}
|
||||
|
||||
|
|
@ -1261,9 +1258,6 @@ packages:
|
|||
resolution: {integrity: sha512-eRWB5LBz7PpDu4PUlwT0PhnQfTQJlDDdPa35urV4Osrm0t0AqQFGn+UIkU3klZvwJ8KPO3VbBFsXquA6p6kqZw==}
|
||||
engines: {node: '>=12', npm: '>=6'}
|
||||
|
||||
pend@1.2.0:
|
||||
resolution: {integrity: sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==}
|
||||
|
||||
picocolors@1.1.1:
|
||||
resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==}
|
||||
|
||||
|
|
@ -1397,6 +1391,11 @@ packages:
|
|||
engines: {node: '>=10'}
|
||||
hasBin: true
|
||||
|
||||
semver@7.8.5:
|
||||
resolution: {integrity: sha512-Y7/KDsb8LjooZpwaqGyulO6DQlksgCncchHGk+sZIY4SBvUocMBEFH5Ur1fI4dV+Jvl0w6cjvucaIi40puRioA==}
|
||||
engines: {node: '>=10'}
|
||||
hasBin: true
|
||||
|
||||
serialize-error@7.0.1:
|
||||
resolution: {integrity: sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==}
|
||||
engines: {node: '>=10'}
|
||||
|
|
@ -1554,12 +1553,13 @@ packages:
|
|||
resolution: {integrity: sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
undici-types@7.16.0:
|
||||
resolution: {integrity: sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==}
|
||||
|
||||
undici-types@7.18.2:
|
||||
resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==}
|
||||
|
||||
undici@7.28.0:
|
||||
resolution: {integrity: sha512-cRZYrTDwWznlnRiPjggAGxZXanty6M8RV1ff8Wm4LWXBp7/IG8v5DnOm74DtUBp9OONpK75YlPnIjQqX0dBDtA==}
|
||||
engines: {node: '>=20.18.1'}
|
||||
|
||||
unique-filename@4.0.0:
|
||||
resolution: {integrity: sha512-XSnEewXmQ+veP7xX2dS5Q4yZAvO40cBN2MWkJ7D/6sW4Dg6wYBNwM1Vrnz1FhH5AdeLIlUXRI9e28z1YZi71NQ==}
|
||||
engines: {node: ^18.17.0 || >=20.5.0}
|
||||
|
|
@ -1644,9 +1644,6 @@ packages:
|
|||
resolution: {integrity: sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==}
|
||||
engines: {node: '>=12'}
|
||||
|
||||
yauzl@2.10.0:
|
||||
resolution: {integrity: sha512-p4a9I6X6nu6IhoGmBqAcbJy1mlC4j27vEPZX9F4L4/vZT3Lyq1VkFHw/V/PUcB9Buo+DG3iHkT0x3Qya58zc3g==}
|
||||
|
||||
yocto-queue@0.1.0:
|
||||
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
|
||||
engines: {node: '>=10'}
|
||||
|
|
@ -1660,6 +1657,8 @@ snapshots:
|
|||
ajv: 6.14.0
|
||||
ajv-keywords: 3.5.2(ajv@6.14.0)
|
||||
|
||||
'@electron-internal/extract-zip@1.0.3': {}
|
||||
|
||||
'@electron/asar@3.4.1':
|
||||
dependencies:
|
||||
commander: 5.1.0
|
||||
|
|
@ -1672,7 +1671,7 @@ snapshots:
|
|||
fs-extra: 9.1.0
|
||||
minimist: 1.2.8
|
||||
|
||||
'@electron/get@2.0.3':
|
||||
'@electron/get@3.1.0':
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
env-paths: 2.2.1
|
||||
|
|
@ -1686,17 +1685,16 @@ snapshots:
|
|||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@electron/get@3.1.0':
|
||||
'@electron/get@5.0.0':
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
env-paths: 2.2.1
|
||||
fs-extra: 8.1.0
|
||||
got: 11.8.6
|
||||
env-paths: 3.0.0
|
||||
graceful-fs: 4.2.11
|
||||
progress: 2.0.3
|
||||
semver: 6.3.1
|
||||
semver: 7.8.5
|
||||
sumchecker: 3.0.1
|
||||
optionalDependencies:
|
||||
global-agent: 3.0.0
|
||||
undici: 7.28.0
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
|
|
@ -1753,7 +1751,7 @@ snapshots:
|
|||
dependencies:
|
||||
cross-dirname: 0.1.0
|
||||
debug: 4.4.3
|
||||
fs-extra: 11.3.4
|
||||
fs-extra: 11.3.5
|
||||
minimist: 1.2.8
|
||||
postject: 1.0.0-alpha.6
|
||||
transitivePeerDependencies:
|
||||
|
|
@ -1930,9 +1928,9 @@ snapshots:
|
|||
|
||||
'@types/ms@2.1.0': {}
|
||||
|
||||
'@types/node@24.12.0':
|
||||
'@types/node@24.13.2':
|
||||
dependencies:
|
||||
undici-types: 7.16.0
|
||||
undici-types: 7.18.2
|
||||
|
||||
'@types/node@25.5.0':
|
||||
dependencies:
|
||||
|
|
@ -1951,11 +1949,6 @@ snapshots:
|
|||
'@types/verror@1.10.11':
|
||||
optional: true
|
||||
|
||||
'@types/yauzl@2.10.3':
|
||||
dependencies:
|
||||
'@types/node': 25.5.0
|
||||
optional: true
|
||||
|
||||
'@xmldom/xmldom@0.8.11': {}
|
||||
|
||||
abbrev@3.0.1: {}
|
||||
|
|
@ -2100,8 +2093,6 @@ snapshots:
|
|||
dependencies:
|
||||
balanced-match: 4.0.4
|
||||
|
||||
buffer-crc32@0.2.13: {}
|
||||
|
||||
buffer-from@1.1.2: {}
|
||||
|
||||
buffer@5.7.1:
|
||||
|
|
@ -2428,11 +2419,11 @@ snapshots:
|
|||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
electron@41.0.2:
|
||||
electron@42.4.0:
|
||||
dependencies:
|
||||
'@electron/get': 2.0.3
|
||||
'@types/node': 24.12.0
|
||||
extract-zip: 2.0.1
|
||||
'@electron-internal/extract-zip': 1.0.3
|
||||
'@electron/get': 5.0.0
|
||||
'@types/node': 24.13.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
|
|
@ -2509,16 +2500,6 @@ snapshots:
|
|||
|
||||
exponential-backoff@3.1.3: {}
|
||||
|
||||
extract-zip@2.0.1:
|
||||
dependencies:
|
||||
debug: 4.4.3
|
||||
get-stream: 5.2.0
|
||||
yauzl: 2.10.0
|
||||
optionalDependencies:
|
||||
'@types/yauzl': 2.10.3
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
extsprintf@1.4.1:
|
||||
optional: true
|
||||
|
||||
|
|
@ -2528,10 +2509,6 @@ snapshots:
|
|||
|
||||
fast-uri@3.1.0: {}
|
||||
|
||||
fd-slicer@1.1.0:
|
||||
dependencies:
|
||||
pend: 1.2.0
|
||||
|
||||
fdir@6.5.0(picomatch@4.0.3):
|
||||
optionalDependencies:
|
||||
picomatch: 4.0.3
|
||||
|
|
@ -2569,6 +2546,13 @@ snapshots:
|
|||
jsonfile: 6.2.0
|
||||
universalify: 2.0.1
|
||||
|
||||
fs-extra@11.3.5:
|
||||
dependencies:
|
||||
graceful-fs: 4.2.11
|
||||
jsonfile: 6.2.1
|
||||
universalify: 2.0.1
|
||||
optional: true
|
||||
|
||||
fs-extra@7.0.1:
|
||||
dependencies:
|
||||
graceful-fs: 4.2.11
|
||||
|
|
@ -2804,6 +2788,13 @@ snapshots:
|
|||
optionalDependencies:
|
||||
graceful-fs: 4.2.11
|
||||
|
||||
jsonfile@6.2.1:
|
||||
dependencies:
|
||||
universalify: 2.0.1
|
||||
optionalDependencies:
|
||||
graceful-fs: 4.2.11
|
||||
optional: true
|
||||
|
||||
keyv@4.5.4:
|
||||
dependencies:
|
||||
json-buffer: 3.0.1
|
||||
|
|
@ -3015,8 +3006,6 @@ snapshots:
|
|||
|
||||
pe-library@0.4.1: {}
|
||||
|
||||
pend@1.2.0: {}
|
||||
|
||||
picocolors@1.1.1: {}
|
||||
|
||||
picomatch@4.0.3: {}
|
||||
|
|
@ -3136,6 +3125,8 @@ snapshots:
|
|||
|
||||
semver@7.7.4: {}
|
||||
|
||||
semver@7.8.5: {}
|
||||
|
||||
serialize-error@7.0.1:
|
||||
dependencies:
|
||||
type-fest: 0.13.1
|
||||
|
|
@ -3295,10 +3286,11 @@ snapshots:
|
|||
|
||||
uint8array-extras@1.5.0: {}
|
||||
|
||||
undici-types@7.16.0: {}
|
||||
|
||||
undici-types@7.18.2: {}
|
||||
|
||||
undici@7.28.0:
|
||||
optional: true
|
||||
|
||||
unique-filename@4.0.0:
|
||||
dependencies:
|
||||
unique-slug: 5.0.0
|
||||
|
|
@ -3384,9 +3376,4 @@ snapshots:
|
|||
y18n: 5.0.8
|
||||
yargs-parser: 21.1.1
|
||||
|
||||
yauzl@2.10.0:
|
||||
dependencies:
|
||||
buffer-crc32: 0.2.13
|
||||
fd-slicer: 1.1.0
|
||||
|
||||
yocto-queue@0.1.0: {}
|
||||
|
|
|
|||
|
|
@ -40,8 +40,12 @@ export const IPC_CHANNELS = {
|
|||
READ_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:read-local-file-text',
|
||||
WRITE_AGENT_LOCAL_FILE_TEXT: 'agent-filesystem:write-local-file-text',
|
||||
// Auth token sync across windows
|
||||
GET_AUTH_TOKENS: 'auth:get-tokens',
|
||||
SET_AUTH_TOKENS: 'auth:set-tokens',
|
||||
GET_ACCESS_TOKEN: 'auth:get-access-token',
|
||||
REFRESH_ACCESS_TOKEN: 'auth:refresh-access-token',
|
||||
LOGOUT: 'auth:logout',
|
||||
AUTH_CHANGED: 'auth:changed',
|
||||
AUTH_START_GOOGLE: 'auth:start-google',
|
||||
AUTH_LOGIN_PASSWORD: 'auth:login-password',
|
||||
// Keyboard shortcut configuration
|
||||
GET_SHORTCUTS: 'shortcuts:get',
|
||||
SET_SHORTCUTS: 'shortcuts:set',
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { app, ipcMain, shell } from 'electron';
|
||||
import { app, BrowserWindow, ipcMain, shell } from 'electron';
|
||||
import { IPC_CHANNELS } from './channels';
|
||||
import {
|
||||
getPermissionsStatus,
|
||||
|
|
@ -52,8 +52,64 @@ import {
|
|||
type AgentFilesystemTreeWatchOptions,
|
||||
} from '../modules/agent-filesystem-tree-watcher';
|
||||
import { installDownloadedUpdate } from '../modules/auto-updater';
|
||||
import { secretStore } from '../modules/secret-store';
|
||||
import { startGoogleOAuth } from '../modules/oauth';
|
||||
|
||||
let authTokens: { bearer: string; refresh: string } | null = null;
|
||||
const REFRESH_TOKEN_KEY = 'surfsense_refresh_token';
|
||||
let accessToken: string | null = null;
|
||||
let refreshInFlight: Promise<string | null> | null = null;
|
||||
|
||||
type DesktopAuthResponse = {
|
||||
access_token?: string;
|
||||
refresh_token?: string | null;
|
||||
};
|
||||
|
||||
function getBackendUrl(): string {
|
||||
return (process.env.HOSTED_BACKEND_URL || process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || '').replace(
|
||||
/\/+$/,
|
||||
''
|
||||
);
|
||||
}
|
||||
|
||||
function broadcastAuthChanged(): void {
|
||||
for (const win of BrowserWindow.getAllWindows()) {
|
||||
win.webContents.send(IPC_CHANNELS.AUTH_CHANGED, { authed: !!accessToken, accessToken });
|
||||
}
|
||||
}
|
||||
|
||||
async function storeTokens(tokens: { bearer: string; refresh?: string | null }): Promise<void> {
|
||||
accessToken = tokens.bearer || null;
|
||||
if (tokens.refresh) {
|
||||
await secretStore.set(REFRESH_TOKEN_KEY, tokens.refresh);
|
||||
}
|
||||
broadcastAuthChanged();
|
||||
}
|
||||
|
||||
async function refreshAccessToken(): Promise<string | null> {
|
||||
if (refreshInFlight) return refreshInFlight;
|
||||
|
||||
refreshInFlight = (async () => {
|
||||
const refresh = await secretStore.get(REFRESH_TOKEN_KEY);
|
||||
const backendUrl = getBackendUrl();
|
||||
if (!refresh || !backendUrl) return null;
|
||||
|
||||
const response = await fetch(`${backendUrl}/auth/jwt/refresh`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refresh_token: refresh }),
|
||||
});
|
||||
if (!response.ok) return null;
|
||||
|
||||
const data = (await response.json()) as { access_token?: string; refresh_token?: string | null };
|
||||
if (!data.access_token) return null;
|
||||
await storeTokens({ bearer: data.access_token, refresh: data.refresh_token });
|
||||
return data.access_token;
|
||||
})().finally(() => {
|
||||
refreshInFlight = null;
|
||||
});
|
||||
|
||||
return refreshInFlight;
|
||||
}
|
||||
|
||||
export function registerIpcHandlers(): void {
|
||||
ipcMain.on(IPC_CHANNELS.OPEN_EXTERNAL, (_event, url: string) => {
|
||||
|
|
@ -173,14 +229,81 @@ export function registerIpcHandlers(): void {
|
|||
}
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
|
||||
authTokens = tokens;
|
||||
ipcMain.handle(IPC_CHANNELS.GET_ACCESS_TOKEN, async () => {
|
||||
if (!accessToken) {
|
||||
await refreshAccessToken();
|
||||
}
|
||||
return accessToken;
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_AUTH_TOKENS, () => {
|
||||
return authTokens;
|
||||
ipcMain.handle(IPC_CHANNELS.REFRESH_ACCESS_TOKEN, () => {
|
||||
return refreshAccessToken();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.LOGOUT, async () => {
|
||||
const backendUrl = getBackendUrl();
|
||||
const refresh = await secretStore.get(REFRESH_TOKEN_KEY);
|
||||
if (backendUrl && refresh) {
|
||||
try {
|
||||
await fetch(`${backendUrl}/auth/jwt/revoke`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refresh_token: refresh }),
|
||||
});
|
||||
} catch {
|
||||
// Local logout is fail-closed even if the server revoke call fails.
|
||||
}
|
||||
}
|
||||
accessToken = null;
|
||||
await secretStore.clear(REFRESH_TOKEN_KEY);
|
||||
broadcastAuthChanged();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.AUTH_START_GOOGLE, async () => {
|
||||
const backendUrl = getBackendUrl();
|
||||
if (!backendUrl) {
|
||||
throw new Error('Backend URL is not configured');
|
||||
}
|
||||
const tokens = await startGoogleOAuth(backendUrl);
|
||||
await storeTokens({ bearer: tokens.access_token, refresh: tokens.refresh_token });
|
||||
return { ok: true };
|
||||
});
|
||||
|
||||
ipcMain.handle(
|
||||
IPC_CHANNELS.AUTH_LOGIN_PASSWORD,
|
||||
async (_event, payload: { email: string; password: string }) => {
|
||||
const backendUrl = getBackendUrl();
|
||||
if (!backendUrl) {
|
||||
throw new Error('Backend URL is not configured');
|
||||
}
|
||||
|
||||
const response = await fetch(`${backendUrl}/auth/desktop/login`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let detail = 'Password login failed';
|
||||
try {
|
||||
const error = (await response.json()) as { detail?: string };
|
||||
detail = error.detail || detail;
|
||||
} catch {
|
||||
// Keep the generic error if the backend did not return JSON.
|
||||
}
|
||||
throw new Error(detail);
|
||||
}
|
||||
|
||||
const tokens = (await response.json()) as DesktopAuthResponse;
|
||||
if (!tokens.access_token || !tokens.refresh_token) {
|
||||
throw new Error('Password login did not return desktop tokens');
|
||||
}
|
||||
|
||||
await storeTokens({ bearer: tokens.access_token, refresh: tokens.refresh_token });
|
||||
return { ok: true };
|
||||
}
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_SHORTCUTS, () => getShortcuts());
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_AUTO_LAUNCH, () => getAutoLaunchState());
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import {
|
|||
syncAutoLaunchOnStartup,
|
||||
wasLaunchedAtLogin,
|
||||
} from './modules/auto-launch';
|
||||
import { purgeLegacyAuthCutover } from './modules/auth-cutover';
|
||||
|
||||
registerGlobalErrorHandlers();
|
||||
app.setName('SurfSense');
|
||||
|
|
@ -29,6 +30,7 @@ registerIpcHandlers();
|
|||
|
||||
app.whenReady().then(async () => {
|
||||
initAnalytics();
|
||||
await purgeLegacyAuthCutover();
|
||||
const launchedAtLogin = wasLaunchedAtLogin();
|
||||
const startedHidden = shouldStartHidden();
|
||||
trackEvent('desktop_app_launched', {
|
||||
|
|
|
|||
30
surfsense_desktop/src/modules/auth-cutover.ts
Normal file
30
surfsense_desktop/src/modules/auth-cutover.ts
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import { app } from 'electron';
|
||||
import { mkdir, readFile, writeFile } from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import { secretStore } from './secret-store';
|
||||
|
||||
const CUTOVER_FLAG_FILE = 'auth-cutover-v1.json';
|
||||
const REFRESH_TOKEN_KEY = 'surfsense_refresh_token';
|
||||
|
||||
async function hasCompletedCutover(flagPath: string): Promise<boolean> {
|
||||
try {
|
||||
const raw = await readFile(flagPath, 'utf8');
|
||||
return JSON.parse(raw)?.complete === true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function purgeLegacyAuthCutover(): Promise<void> {
|
||||
const userDataPath = app.getPath('userData');
|
||||
const flagPath = path.join(userDataPath, CUTOVER_FLAG_FILE);
|
||||
if (await hasCompletedCutover(flagPath)) return;
|
||||
|
||||
await secretStore.clear(REFRESH_TOKEN_KEY);
|
||||
await mkdir(userDataPath, { recursive: true });
|
||||
await writeFile(
|
||||
flagPath,
|
||||
JSON.stringify({ complete: true, completedAt: new Date().toISOString() }),
|
||||
{ mode: 0o600 }
|
||||
);
|
||||
}
|
||||
|
|
@ -22,8 +22,7 @@ function handleDeepLink(url: string) {
|
|||
path: parsed.pathname,
|
||||
});
|
||||
if (parsed.hostname === 'auth' && parsed.pathname === '/callback') {
|
||||
const params = parsed.searchParams.toString();
|
||||
win.loadURL(`${getServerOrigin()}/auth/callback?${params}`);
|
||||
win.loadURL(`${getServerOrigin()}/dashboard`);
|
||||
}
|
||||
|
||||
win.show();
|
||||
|
|
|
|||
72
surfsense_desktop/src/modules/oauth-page.ts
Normal file
72
surfsense_desktop/src/modules/oauth-page.ts
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import http from 'node:http';
|
||||
|
||||
function escapeHtml(value: string): string {
|
||||
return value
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, ''');
|
||||
}
|
||||
|
||||
function renderOAuthPage(title: string, message: string): string {
|
||||
return `<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>${escapeHtml(title)}</title>
|
||||
<style>
|
||||
:root {
|
||||
color-scheme: dark;
|
||||
}
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
margin: 0;
|
||||
min-height: 100vh;
|
||||
display: grid;
|
||||
place-items: center;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||
background: #303030;
|
||||
background: oklch(0.24 0 0);
|
||||
color: #fafafa;
|
||||
}
|
||||
main {
|
||||
width: min(420px, calc(100vw - 32px));
|
||||
text-align: center;
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 12px;
|
||||
font-size: 24px;
|
||||
line-height: 1.2;
|
||||
letter-spacing: -0.02em;
|
||||
}
|
||||
p {
|
||||
margin: 0;
|
||||
color: #d4d4d4;
|
||||
line-height: 1.5;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<h1>${escapeHtml(title)}</h1>
|
||||
<p>${escapeHtml(message)}</p>
|
||||
</main>
|
||||
</body>
|
||||
</html>`;
|
||||
}
|
||||
|
||||
export function writeOAuthPage(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
title: string,
|
||||
message: string,
|
||||
_tone?: 'success' | 'error' | 'neutral',
|
||||
): void {
|
||||
res
|
||||
.writeHead(statusCode, { 'content-type': 'text/html; charset=utf-8' })
|
||||
.end(renderOAuthPage(title, message));
|
||||
}
|
||||
155
surfsense_desktop/src/modules/oauth.ts
Normal file
155
surfsense_desktop/src/modules/oauth.ts
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
import { shell } from 'electron';
|
||||
import crypto from 'node:crypto';
|
||||
import http from 'node:http';
|
||||
import { writeOAuthPage } from './oauth-page';
|
||||
|
||||
export interface DesktopAuthTokens {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
}
|
||||
|
||||
const OAUTH_TIMEOUT_MS = 5 * 60 * 1000;
|
||||
const OAUTH_CALLBACK_PATH = '/callback';
|
||||
|
||||
function base64Url(buffer: Buffer): string {
|
||||
return buffer.toString('base64').replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
}
|
||||
|
||||
function randomUrlSafe(bytes = 32): string {
|
||||
return base64Url(crypto.randomBytes(bytes));
|
||||
}
|
||||
|
||||
function sha256(value: string): string {
|
||||
return base64Url(crypto.createHash('sha256').update(value).digest());
|
||||
}
|
||||
|
||||
function getGoogleDesktopClientId(): string {
|
||||
const clientId = (process.env.GOOGLE_DESKTOP_CLIENT_ID || '').trim();
|
||||
if (!clientId) {
|
||||
throw new Error('Google desktop OAuth client ID is not configured');
|
||||
}
|
||||
return clientId;
|
||||
}
|
||||
|
||||
export async function startGoogleOAuth(backendUrl: string): Promise<DesktopAuthTokens> {
|
||||
const clientId = getGoogleDesktopClientId();
|
||||
const state = randomUrlSafe();
|
||||
const codeVerifier = randomUrlSafe(64);
|
||||
const codeChallenge = sha256(codeVerifier);
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let settled = false;
|
||||
let port: number | null = null;
|
||||
let timeout: NodeJS.Timeout | null = null;
|
||||
|
||||
const cleanup = () => {
|
||||
if (timeout) {
|
||||
clearTimeout(timeout);
|
||||
timeout = null;
|
||||
}
|
||||
if (server.listening) {
|
||||
server.close();
|
||||
}
|
||||
};
|
||||
|
||||
const fail = (error: Error) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
|
||||
const succeed = (tokens: DesktopAuthTokens) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
resolve(tokens);
|
||||
};
|
||||
|
||||
const server = http.createServer(async (req, res) => {
|
||||
try {
|
||||
const url = new URL(req.url || '/', 'http://127.0.0.1');
|
||||
if (url.pathname !== OAUTH_CALLBACK_PATH) {
|
||||
writeOAuthPage(res, 404, 'Not found', 'This OAuth callback endpoint is only used by SurfSense.');
|
||||
return;
|
||||
}
|
||||
|
||||
const oauthError = url.searchParams.get('error');
|
||||
if (oauthError) {
|
||||
const description = url.searchParams.get('error_description');
|
||||
writeOAuthPage(res, 400, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
|
||||
fail(new Error(description || `Google OAuth failed: ${oauthError}`));
|
||||
return;
|
||||
}
|
||||
|
||||
const code = url.searchParams.get('code');
|
||||
const returnedState = url.searchParams.get('state');
|
||||
if (!code || returnedState !== state) {
|
||||
writeOAuthPage(res, 400, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
|
||||
fail(new Error('Invalid OAuth callback'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!port) {
|
||||
writeOAuthPage(res, 500, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
|
||||
fail(new Error('OAuth loopback server was not ready'));
|
||||
return;
|
||||
}
|
||||
|
||||
const redirectUri = `http://127.0.0.1:${port}${OAUTH_CALLBACK_PATH}`;
|
||||
const response = await fetch(`${backendUrl}/auth/desktop/session`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ code, code_verifier: codeVerifier, redirect_uri: redirectUri }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
let detail = 'Desktop session exchange failed';
|
||||
try {
|
||||
const error = (await response.json()) as { detail?: string };
|
||||
detail = error.detail || detail;
|
||||
} catch {
|
||||
// Keep the generic exchange error if the backend did not return JSON.
|
||||
}
|
||||
writeOAuthPage(res, 401, 'Authentication failed', 'You can close this window and return to SurfSense.', 'error');
|
||||
fail(new Error(detail));
|
||||
return;
|
||||
}
|
||||
const tokens = (await response.json()) as DesktopAuthTokens;
|
||||
writeOAuthPage(res, 200, 'Authentication complete', 'You can close this window and return to SurfSense.', 'success');
|
||||
succeed(tokens);
|
||||
} catch (error) {
|
||||
fail(error instanceof Error ? error : new Error('Google OAuth failed'));
|
||||
}
|
||||
});
|
||||
|
||||
server.listen(0, '127.0.0.1', () => {
|
||||
const addressInfo = server.address();
|
||||
if (!addressInfo || typeof addressInfo === 'string') {
|
||||
fail(new Error('Unable to bind loopback OAuth server'));
|
||||
return;
|
||||
}
|
||||
port = addressInfo.port;
|
||||
timeout = setTimeout(() => {
|
||||
fail(new Error('Google OAuth timed out'));
|
||||
}, OAUTH_TIMEOUT_MS);
|
||||
|
||||
const redirectUri = `http://127.0.0.1:${port}${OAUTH_CALLBACK_PATH}`;
|
||||
const authUrl = new URL('https://accounts.google.com/o/oauth2/v2/auth');
|
||||
authUrl.searchParams.set('client_id', clientId);
|
||||
authUrl.searchParams.set('redirect_uri', redirectUri);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
authUrl.searchParams.set('scope', 'openid email profile');
|
||||
authUrl.searchParams.set('state', state);
|
||||
authUrl.searchParams.set('code_challenge', codeChallenge);
|
||||
authUrl.searchParams.set('code_challenge_method', 'S256');
|
||||
|
||||
shell.openExternal(authUrl.toString()).catch((error) => {
|
||||
fail(error instanceof Error ? error : new Error('Unable to open browser for Google OAuth'));
|
||||
});
|
||||
});
|
||||
|
||||
server.on('error', (error) => {
|
||||
fail(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
86
surfsense_desktop/src/modules/secret-store.ts
Normal file
86
surfsense_desktop/src/modules/secret-store.ts
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import { app, safeStorage } from 'electron';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
|
||||
export interface SecretStore {
|
||||
set(key: string, value: string): Promise<void>;
|
||||
get(key: string): Promise<string | null>;
|
||||
clear(key: string): Promise<void>;
|
||||
isHardwareBacked(): Promise<boolean>;
|
||||
}
|
||||
|
||||
const memoryStore = new Map<string, string>();
|
||||
const storePath = path.join(app.getPath('userData'), 'secrets.enc.json');
|
||||
|
||||
async function readDiskStore(): Promise<Record<string, string>> {
|
||||
try {
|
||||
const raw = await fs.readFile(storePath, 'utf8');
|
||||
return JSON.parse(raw) as Record<string, string>;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
async function writeDiskStore(data: Record<string, string>): Promise<void> {
|
||||
await fs.mkdir(path.dirname(storePath), { recursive: true });
|
||||
await fs.writeFile(storePath, JSON.stringify(data), { encoding: 'utf8', mode: 0o600 });
|
||||
}
|
||||
|
||||
async function canPersistEncryptedSecrets(): Promise<boolean> {
|
||||
try {
|
||||
if (safeStorage.getSelectedStorageBackend?.() === 'basic_text') {
|
||||
return false;
|
||||
}
|
||||
return await safeStorage.isAsyncEncryptionAvailable();
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export const secretStore: SecretStore = {
|
||||
async set(key, value) {
|
||||
if (!(await canPersistEncryptedSecrets())) {
|
||||
memoryStore.set(key, value);
|
||||
return;
|
||||
}
|
||||
|
||||
const encrypted = await safeStorage.encryptStringAsync(value);
|
||||
const data = await readDiskStore();
|
||||
data[key] = encrypted.toString('base64');
|
||||
await writeDiskStore(data);
|
||||
},
|
||||
|
||||
async get(key) {
|
||||
if (!(await canPersistEncryptedSecrets())) {
|
||||
return memoryStore.get(key) ?? null;
|
||||
}
|
||||
|
||||
const data = await readDiskStore();
|
||||
const encoded = data[key];
|
||||
if (!encoded) return null;
|
||||
|
||||
try {
|
||||
const decrypted = await safeStorage.decryptStringAsync(Buffer.from(encoded, 'base64'));
|
||||
if (decrypted.shouldReEncrypt) {
|
||||
await this.set(key, decrypted.result);
|
||||
}
|
||||
return decrypted.result;
|
||||
} catch {
|
||||
await this.clear(key);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
async clear(key) {
|
||||
memoryStore.delete(key);
|
||||
const data = await readDiskStore();
|
||||
if (key in data) {
|
||||
delete data[key];
|
||||
await writeDiskStore(data);
|
||||
}
|
||||
},
|
||||
|
||||
async isHardwareBacked() {
|
||||
return canPersistEncryptedSecrets();
|
||||
},
|
||||
};
|
||||
|
|
@ -94,6 +94,10 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow {
|
|||
session.defaultSession.webRequest.onBeforeRequest(rewriteFilter, (details, callback) => {
|
||||
try {
|
||||
const u = new URL(details.url);
|
||||
if (!u.pathname.includes('/connectors/callback')) {
|
||||
callback({});
|
||||
return;
|
||||
}
|
||||
const originalHost = u.host;
|
||||
const local = new URL(getServerOrigin());
|
||||
u.protocol = local.protocol;
|
||||
|
|
|
|||
|
|
@ -80,9 +80,18 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
ipcRenderer.invoke(IPC_CHANNELS.WRITE_AGENT_LOCAL_FILE_TEXT, virtualPath, content, searchSpaceId),
|
||||
|
||||
// Auth token sync across windows
|
||||
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
|
||||
setAuthTokens: (bearer: string, refresh: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.SET_AUTH_TOKENS, { bearer, refresh }),
|
||||
getAccessToken: () => ipcRenderer.invoke(IPC_CHANNELS.GET_ACCESS_TOKEN),
|
||||
refreshAccessToken: () => ipcRenderer.invoke(IPC_CHANNELS.REFRESH_ACCESS_TOKEN),
|
||||
logout: () => ipcRenderer.invoke(IPC_CHANNELS.LOGOUT),
|
||||
startGoogleOAuth: () => ipcRenderer.invoke(IPC_CHANNELS.AUTH_START_GOOGLE),
|
||||
loginPassword: (email: string, password: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.AUTH_LOGIN_PASSWORD, { email, password }),
|
||||
onAuthChanged: (callback: (payload: { authed: boolean; accessToken: string | null }) => void) => {
|
||||
const listener = (_event: Electron.IpcRendererEvent, payload: { authed: boolean; accessToken: string | null }) =>
|
||||
callback(payload);
|
||||
ipcRenderer.on(IPC_CHANNELS.AUTH_CHANGED, listener);
|
||||
return () => ipcRenderer.removeListener(IPC_CHANNELS.AUTH_CHANGED, listener);
|
||||
},
|
||||
|
||||
// Keyboard shortcut configuration
|
||||
getShortcuts: () => ipcRenderer.invoke(IPC_CHANNELS.GET_SHORTCUTS),
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ SurfSense supports ``AUTH_TYPE=LOCAL`` (email + password) and
|
|||
There is no headless equivalent of the Google flow, so the harness handles
|
||||
both modes by treating the JWT as the universal credential:
|
||||
|
||||
* **LOCAL**: harness POSTs form-encoded ``username`` + ``password`` to
|
||||
``/auth/jwt/login``, reads ``{access_token, refresh_token}``.
|
||||
* **LOCAL**: harness POSTs JSON ``email`` + ``password`` to
|
||||
``/auth/desktop/login``, reads ``{access_token, refresh_token}``.
|
||||
* **GOOGLE / pre-issued JWT**: operator pastes their existing JWT (and
|
||||
optionally refresh token) into ``SURFSENSE_JWT`` /
|
||||
``SURFSENSE_REFRESH_TOKEN``; harness skips login.
|
||||
|
|
@ -22,7 +22,7 @@ MIRAGE runs.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
|
@ -40,9 +40,8 @@ _NO_CREDENTIALS_MESSAGE = (
|
|||
"No SurfSense credentials configured. Set ONE of:\n"
|
||||
" (LOCAL) SURFSENSE_USER_EMAIL + SURFSENSE_USER_PASSWORD\n"
|
||||
" (GOOGLE) SURFSENSE_JWT (and optionally SURFSENSE_REFRESH_TOKEN)\n"
|
||||
"For GOOGLE: log in to SurfSense in your browser, open DevTools → "
|
||||
"Application → Local Storage → copy `surfsense_bearer_token` and "
|
||||
"`surfsense_refresh_token` into those env vars."
|
||||
"For GOOGLE: use a PAT or operator-issued bearer token and set "
|
||||
"SURFSENSE_JWT (plus SURFSENSE_REFRESH_TOKEN if available)."
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -69,7 +68,7 @@ async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None
|
|||
1. ``SURFSENSE_JWT`` set → use it directly. Refresh token captured if
|
||||
supplied.
|
||||
2. ``SURFSENSE_USER_EMAIL`` + ``SURFSENSE_USER_PASSWORD`` set →
|
||||
form-encoded POST to ``/auth/jwt/login``.
|
||||
JSON POST to ``/auth/desktop/login``.
|
||||
3. Neither → raise ``CredentialError``.
|
||||
|
||||
The optional ``http`` argument lets tests inject a mocked client; if
|
||||
|
|
@ -86,9 +85,9 @@ async def acquire_token(config: Config, *, http: httpx.AsyncClient | None = None
|
|||
if config.has_local_mode():
|
||||
async def _login(client: httpx.AsyncClient) -> TokenBundle:
|
||||
response = await client.post(
|
||||
f"{config.surfsense_api_base}/auth/jwt/login",
|
||||
data={
|
||||
"username": config.surfsense_user_email,
|
||||
f"{config.surfsense_api_base}/auth/desktop/login",
|
||||
json={
|
||||
"email": config.surfsense_user_email,
|
||||
"password": config.surfsense_user_password,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ async def test_acquire_token_jwt_mode_short_circuits():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_acquire_token_local_mode_posts_form():
|
||||
respx.post("http://test/auth/jwt/login").mock(
|
||||
async def test_acquire_token_local_mode_posts_desktop_login_json():
|
||||
respx.post("http://test/auth/desktop/login").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"access_token": "T", "refresh_token": "R", "token_type": "bearer"}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ NEXT_PUBLIC_POSTHOG_HOST=https://us.i.posthog.com
|
|||
# "/zero" endpoint behind Caddy. Set it for local dev or packaged clients.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848
|
||||
# Server-only shared secret that authorizes zero-cache when it calls
|
||||
# /api/zero/query. Leave unset during the compatibility rollout, then set it
|
||||
# once every zero-cache instance sends X-Api-Key.
|
||||
# ZERO_QUERY_API_KEY=
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Cloudflare Turnstile CAPTCHA for anonymous chat abuse prevention
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import { useRuntimeConfig } from "@/components/providers/runtime-config";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors";
|
||||
import { getPostLoginRedirectPath } from "@/lib/auth-utils";
|
||||
import { ValidationError } from "@/lib/error";
|
||||
import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events";
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ export function LocalLoginForm() {
|
|||
trackLoginAttempt("local");
|
||||
|
||||
try {
|
||||
const data = await login({
|
||||
await login({
|
||||
username,
|
||||
password,
|
||||
grant_type: "password",
|
||||
|
|
@ -47,14 +48,9 @@ export function LocalLoginForm() {
|
|||
// Track successful login
|
||||
trackLoginSuccess("local");
|
||||
|
||||
// Set flag so TokenHandler knows local login was already tracked
|
||||
if (typeof window !== "undefined") {
|
||||
sessionStorage.setItem("login_success_tracked", "true");
|
||||
}
|
||||
|
||||
// Small delay to show success message
|
||||
setTimeout(() => {
|
||||
router.push(`/auth/callback?token=${data.access_token}`);
|
||||
router.push(getPostLoginRedirectPath());
|
||||
}, 500);
|
||||
} catch (err) {
|
||||
if (err instanceof ValidationError) {
|
||||
|
|
|
|||
|
|
@ -30,8 +30,7 @@ function LoginContent() {
|
|||
const logout = searchParams.get("logout");
|
||||
const returnUrl = searchParams.get("returnUrl");
|
||||
|
||||
// Save returnUrl to localStorage so it persists through OAuth flows (e.g., Google)
|
||||
// This is read by TokenHandler after successful authentication
|
||||
// Save returnUrl for client-side login flows that can redirect directly after success.
|
||||
if (returnUrl) {
|
||||
setRedirectPath(decodeURIComponent(returnUrl));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ import { Logo } from "@/components/Logo";
|
|||
import { useRuntimeConfig } from "@/components/providers/runtime-config";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useSession } from "@/hooks/use-session";
|
||||
import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { AppError, ValidationError } from "@/lib/error";
|
||||
import {
|
||||
trackRegistrationAttempt,
|
||||
|
|
@ -37,18 +37,19 @@ export default function RegisterPage() {
|
|||
message: null,
|
||||
});
|
||||
const router = useRouter();
|
||||
const session = useSession();
|
||||
const [{ mutateAsync: register, isPending: isRegistering }] = useAtom(registerMutationAtom);
|
||||
|
||||
// Check authentication type and redirect if not LOCAL
|
||||
useEffect(() => {
|
||||
if (getBearerToken()) {
|
||||
if (session.status === "authenticated") {
|
||||
router.replace("/dashboard");
|
||||
return;
|
||||
}
|
||||
if (authType !== "LOCAL") {
|
||||
router.push("/login");
|
||||
}
|
||||
}, [authType, router]);
|
||||
}, [authType, router, session.status]);
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
|
|
|
|||
|
|
@ -12,45 +12,66 @@ import { schema } from "@/zero/schema";
|
|||
// (e.g. http://localhost:8929) does NOT resolve from inside the frontend
|
||||
// container and would make every authenticated Zero query fail with a 503.
|
||||
const backendURL = SERVER_BACKEND_URL.replace(/\/$/, "");
|
||||
const zeroQueryApiKey = process.env.ZERO_QUERY_API_KEY;
|
||||
|
||||
function validateZeroCacheRequest(request: Request): NextResponse | null {
|
||||
if (!zeroQueryApiKey) return null;
|
||||
if (request.headers.get("X-Api-Key") === zeroQueryApiKey) return null;
|
||||
return NextResponse.json({ error: "Forbidden" }, { status: 403 });
|
||||
}
|
||||
|
||||
async function authenticateRequest(
|
||||
request: Request
|
||||
): Promise<{ ctx: Context; error?: never } | { ctx?: never; error: NextResponse }> {
|
||||
): Promise<
|
||||
{ ctx: Exclude<Context, undefined>; error?: never } | { ctx?: never; error: NextResponse }
|
||||
> {
|
||||
const authHeader = request.headers.get("Authorization");
|
||||
if (!authHeader?.startsWith("Bearer ")) {
|
||||
return { ctx: undefined };
|
||||
const cookieHeader = request.headers.get("Cookie");
|
||||
const headers: HeadersInit = {};
|
||||
if (authHeader?.startsWith("Bearer ")) {
|
||||
headers.Authorization = authHeader;
|
||||
} else if (cookieHeader) {
|
||||
headers.Cookie = cookieHeader;
|
||||
} else {
|
||||
return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) };
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(`${backendURL}/users/me`, {
|
||||
headers: { Authorization: authHeader },
|
||||
const res = await fetch(`${backendURL}/zero/context`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) };
|
||||
}
|
||||
|
||||
const user = await res.json();
|
||||
return { ctx: { userId: String(user.id) } };
|
||||
const ctx = (await res.json()) as Exclude<Context, undefined>;
|
||||
return { ctx };
|
||||
} catch {
|
||||
return { error: NextResponse.json({ error: "Auth service unavailable" }, { status: 503 }) };
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const forbidden = validateZeroCacheRequest(request);
|
||||
if (forbidden) {
|
||||
return forbidden;
|
||||
}
|
||||
|
||||
const auth = await authenticateRequest(request);
|
||||
if (auth.error) {
|
||||
return auth.error;
|
||||
}
|
||||
|
||||
const result = await handleQueryRequest(
|
||||
(name, args) => {
|
||||
const result = await handleQueryRequest({
|
||||
handler: (name, args) => {
|
||||
const query = mustGetQuery(queries, name);
|
||||
return query.fn({ args, ctx: auth.ctx });
|
||||
},
|
||||
schema,
|
||||
request
|
||||
);
|
||||
request,
|
||||
userID: auth.ctx.userId,
|
||||
});
|
||||
|
||||
return NextResponse.json(result);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
"use client";
|
||||
|
||||
import { Suspense } from "react";
|
||||
import TokenHandler from "@/components/TokenHandler";
|
||||
|
||||
export default function AuthCallbackPage() {
|
||||
// Suspense fallback returns null - the GlobalLoadingProvider handles the loading UI
|
||||
// TokenHandler uses useGlobalLoadingEffect to show the loading screen
|
||||
return (
|
||||
<Suspense fallback={null}>
|
||||
<TokenHandler redirectPath="/dashboard" tokenParamName="token" />
|
||||
</Suspense>
|
||||
);
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ import { useMessagesSync } from "@/hooks/use-messages-sync";
|
|||
import { useThreadDetail, useThreadMessages } from "@/hooks/use-thread-queries";
|
||||
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { getDesktopAccessToken } from "@/lib/auth-fetch";
|
||||
import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier";
|
||||
import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
|
|
@ -917,29 +917,26 @@ export default function NewChatPage() {
|
|||
// Cancel ongoing request
|
||||
const cancelRun = useCallback(async () => {
|
||||
if (threadId) {
|
||||
const token = getBearerToken();
|
||||
if (token) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`),
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
const payload = (await response.json()) as {
|
||||
error_code?: string;
|
||||
};
|
||||
if (payload.error_code === "TURN_CANCELLING") {
|
||||
recentCancelRequestedAtRef.current = Date.now();
|
||||
}
|
||||
const token = await getDesktopAccessToken();
|
||||
try {
|
||||
const response = await fetch(
|
||||
buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`),
|
||||
{
|
||||
method: "POST",
|
||||
headers: token ? { Authorization: `Bearer ${token}` } : undefined,
|
||||
credentials: "include",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
const payload = (await response.json()) as {
|
||||
error_code?: string;
|
||||
};
|
||||
if (payload.error_code === "TURN_CANCELLING") {
|
||||
recentCancelRequestedAtRef.current = Date.now();
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
|
||||
}
|
||||
}
|
||||
if (abortControllerRef.current) {
|
||||
|
|
@ -964,11 +961,7 @@ export default function NewChatPage() {
|
|||
|
||||
if (!userQuery.trim() && userImages.length === 0) return;
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
toast.error("Not authenticated. Please log in again.");
|
||||
return;
|
||||
}
|
||||
const token = await getDesktopAccessToken();
|
||||
|
||||
// Lazy thread creation: create thread on first message if it doesn't exist
|
||||
let currentThreadId = threadId;
|
||||
|
|
@ -1149,8 +1142,9 @@ export default function NewChatPage() {
|
|||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({
|
||||
chat_id: currentThreadId,
|
||||
user_query: userQuery.trim(),
|
||||
|
|
@ -1537,12 +1531,7 @@ export default function NewChatPage() {
|
|||
stagedDecisionsByInterruptIdRef.current.clear();
|
||||
setIsRunning(true);
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
toast.error("Not authenticated. Please log in again.");
|
||||
setIsRunning(false);
|
||||
return;
|
||||
}
|
||||
const token = await getDesktopAccessToken();
|
||||
|
||||
const controller = new AbortController();
|
||||
abortControllerRef.current = controller;
|
||||
|
|
@ -1648,8 +1637,9 @@ export default function NewChatPage() {
|
|||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({
|
||||
search_space_id: searchSpaceId,
|
||||
decisions,
|
||||
|
|
@ -1981,11 +1971,7 @@ export default function NewChatPage() {
|
|||
abortControllerRef.current = null;
|
||||
}
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
toast.error("Not authenticated. Please log in again.");
|
||||
return;
|
||||
}
|
||||
const token = await getDesktopAccessToken();
|
||||
|
||||
// Extract the original user query BEFORE removing messages (for reload mode)
|
||||
let userQueryToDisplay: string | undefined;
|
||||
|
|
@ -2104,8 +2090,9 @@ export default function NewChatPage() {
|
|||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: controller.signal,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -13,13 +13,15 @@ import { Logo } from "@/components/Logo";
|
|||
import { ModelProviderConnectionsPanel } from "@/components/settings/model-connections/model-provider-connections-panel";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
import { useSession } from "@/hooks/use-session";
|
||||
import { redirectToLogin } from "@/lib/auth-utils";
|
||||
import { hasEnabledChatModel, isLlmOnboardingComplete } from "@/lib/onboarding";
|
||||
|
||||
export default function OnboardPage() {
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = Number(params.search_space_id);
|
||||
const session = useSession();
|
||||
const { data: globalConnections = [], isLoading: globalLoading } = useAtomValue(
|
||||
globalModelConnectionsAtom
|
||||
);
|
||||
|
|
@ -29,8 +31,8 @@ export default function OnboardPage() {
|
|||
useAtomValue(globalLlmConfigStatusAtom);
|
||||
|
||||
useEffect(() => {
|
||||
if (!getBearerToken()) redirectToLogin();
|
||||
}, []);
|
||||
if (session.status === "unauthenticated") redirectToLogin();
|
||||
}, [session.status]);
|
||||
|
||||
const hasUsableChatModel = useMemo(
|
||||
() => hasEnabledChatModel([...globalConnections, ...connections]),
|
||||
|
|
@ -43,7 +45,8 @@ export default function OnboardPage() {
|
|||
connections
|
||||
);
|
||||
|
||||
const isLoading = globalLoading || rolesLoading || globalConfigStatusLoading;
|
||||
const isLoading =
|
||||
session.status === "loading" || globalLoading || rolesLoading || globalConfigStatusLoading;
|
||||
|
||||
// Onboarding only applies when no global_llm_config.yaml exists. If a global
|
||||
// config is present (or onboarding is already complete), leave this page.
|
||||
|
|
|
|||
|
|
@ -1,10 +1,20 @@
|
|||
"use client";
|
||||
|
||||
import { Check, Copy, Info, Plus, Trash2 } from "lucide-react";
|
||||
import { Check, Copy, Info, Trash2 } from "lucide-react";
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent } from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
|
|
@ -16,6 +26,7 @@ import {
|
|||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { usePats } from "@/hooks/use-pats";
|
||||
import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils";
|
||||
|
||||
|
|
@ -26,6 +37,7 @@ export function ApiKeyContent() {
|
|||
const [label, setLabel] = useState("");
|
||||
const [expiresInDays, setExpiresInDays] = useState("");
|
||||
const [copiedToken, setCopiedToken] = useState(false);
|
||||
const [deleteTarget, setDeleteTarget] = useState<{ id: number; label: string } | null>(null);
|
||||
|
||||
const sortedTokens = useMemo(() => tokens, [tokens]);
|
||||
|
||||
|
|
@ -51,95 +63,112 @@ export function ApiKeyContent() {
|
|||
}
|
||||
}, [createdToken]);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
async (id: number, tokenLabel: string) => {
|
||||
if (!window.confirm(`Delete personal access token "${tokenLabel}"? This cannot be undone.`)) {
|
||||
return;
|
||||
}
|
||||
await deleteToken(id);
|
||||
},
|
||||
[deleteToken]
|
||||
);
|
||||
const handleConfirmDelete = useCallback(async () => {
|
||||
if (!deleteTarget) return;
|
||||
|
||||
await deleteToken(deleteTarget.id);
|
||||
setDeleteTarget(null);
|
||||
}, [deleteTarget, deleteToken]);
|
||||
|
||||
return (
|
||||
<div className="space-y-6 min-w-0 overflow-hidden">
|
||||
<div className="space-y-6 min-w-0">
|
||||
<Alert>
|
||||
<Info />
|
||||
<AlertDescription>
|
||||
Personal access tokens are long-lived credentials for extensions, Obsidian, and
|
||||
programmatic API clients. Copy a token when you create it; it is shown only once.
|
||||
API keys let extensions, Obsidian, and other apps connect to SurfSense.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold tracking-tight">Personal access tokens</h3>
|
||||
<h3 className="text-sm font-semibold tracking-tight">API keys</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Expired tokens stay listed until you delete them.
|
||||
Expired API keys stay listed until you delete them.
|
||||
</p>
|
||||
</div>
|
||||
<Button size="sm" onClick={() => setCreateOpen(true)}>
|
||||
<Plus className="mr-2 h-4 w-4" />
|
||||
Create token
|
||||
Create API key
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="min-w-0 overflow-hidden rounded-lg border border-border/60">
|
||||
{isLoading ? (
|
||||
<div className="space-y-3 p-4">
|
||||
<Skeleton className="h-12 w-full" />
|
||||
<Skeleton className="h-12 w-full" />
|
||||
</div>
|
||||
) : sortedTokens.length > 0 ? (
|
||||
<div className="divide-y divide-border/60">
|
||||
{sortedTokens.map((token) => {
|
||||
const expiresAt = token.expires_at ? new Date(token.expires_at) : null;
|
||||
const isExpired = expiresAt ? expiresAt.getTime() <= Date.now() : false;
|
||||
return (
|
||||
<div key={token.id} className="flex items-center gap-3 p-4">
|
||||
{isLoading ? (
|
||||
<div className="-m-1 grid grid-cols-1 gap-3 p-1">
|
||||
{["skeleton-a", "skeleton-b"].map((key) => (
|
||||
<Card
|
||||
key={key}
|
||||
className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"
|
||||
>
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full min-h-24">
|
||||
<Skeleton className="h-4 w-32 md:w-40 bg-accent" />
|
||||
<Skeleton className="h-3 w-full bg-accent" />
|
||||
<Skeleton className="h-3 w-24 md:w-28 bg-accent" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
) : sortedTokens.length > 0 ? (
|
||||
<div className="-m-1 grid grid-cols-1 gap-3 p-1">
|
||||
{sortedTokens.map((token) => {
|
||||
const expiresAt = token.expires_at ? new Date(token.expires_at) : null;
|
||||
const isExpired = expiresAt ? expiresAt.getTime() <= Date.now() : false;
|
||||
return (
|
||||
<Card
|
||||
key={token.id}
|
||||
className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"
|
||||
>
|
||||
<CardContent className="flex min-h-24 items-center gap-3 p-4">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="truncate text-sm font-medium">{token.label}</p>
|
||||
{isExpired ? <Badge variant="secondary">Expired</Badge> : null}
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<h4 className="truncate text-sm font-semibold tracking-tight">
|
||||
{token.label}
|
||||
</h4>
|
||||
{isExpired ? (
|
||||
<span className="rounded-md border-0 bg-muted px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground">
|
||||
Expired
|
||||
</span>
|
||||
) : null}
|
||||
</div>
|
||||
<p className="truncate font-mono text-xs text-muted-foreground">
|
||||
{token.prefix}...
|
||||
</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Expires: {expiresAt ? expiresAt.toLocaleDateString() : "Never"} · Last used:{" "}
|
||||
{token.last_used_at ? new Date(token.last_used_at).toLocaleString() : "Never"}
|
||||
</p>
|
||||
</div>
|
||||
<p className="font-mono text-xs text-muted-foreground">{token.prefix}...</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Expires: {expiresAt ? expiresAt.toLocaleDateString() : "Never"} · Last used:{" "}
|
||||
{token.last_used_at
|
||||
? new Date(token.last_used_at).toLocaleString()
|
||||
: "Never"}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
disabled={isMutating}
|
||||
onClick={() => handleDelete(token.id, token.label)}
|
||||
onClick={() => setDeleteTarget({ id: token.id, label: token.label })}
|
||||
className="h-7 w-7 shrink-0 rounded-lg text-muted-foreground transition-opacity duration-150 hover:text-accent-foreground sm:opacity-0 sm:pointer-events-none sm:group-hover:opacity-100 sm:group-hover:pointer-events-auto"
|
||||
>
|
||||
<Trash2 className="h-4 w-4 text-muted-foreground" />
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<p className="p-6 text-center text-sm text-muted-foreground">
|
||||
No personal access tokens yet.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<p className="py-6 text-center text-sm text-muted-foreground">
|
||||
No API keys yet.
|
||||
</p>
|
||||
)}
|
||||
|
||||
<Dialog open={createOpen} onOpenChange={setCreateOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create personal access token</DialogTitle>
|
||||
<DialogTitle>Create API key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Name this token so you can recognize where it is used later.
|
||||
Name this API key so you can recognize where it is used later.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="pat-label">Label</Label>
|
||||
<Label htmlFor="pat-label">Name</Label>
|
||||
<Input
|
||||
id="pat-label"
|
||||
value={label}
|
||||
|
|
@ -160,11 +189,24 @@ export function ApiKeyContent() {
|
|||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setCreateOpen(false)}>
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCreateOpen(false)}
|
||||
disabled={isMutating}
|
||||
className="text-sm h-9"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button disabled={isMutating || !label.trim()} onClick={handleCreate}>
|
||||
Create token
|
||||
<Button
|
||||
size="sm"
|
||||
disabled={isMutating || !label.trim()}
|
||||
onClick={handleCreate}
|
||||
className="relative text-sm h-9 min-w-[128px]"
|
||||
>
|
||||
<span className={isMutating ? "opacity-0" : ""}>Create API key</span>
|
||||
{isMutating && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
|
|
@ -173,17 +215,21 @@ export function ApiKeyContent() {
|
|||
<Dialog open={!!createdToken} onOpenChange={(open) => !open && setCreatedToken(null)}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Copy your token now</DialogTitle>
|
||||
<DialogTitle>Copy your API key now</DialogTitle>
|
||||
<DialogDescription>
|
||||
This token is shown only once. Store it somewhere secure before closing this
|
||||
dialog.
|
||||
This API key is shown only once. Store it somewhere secure before closing this dialog.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="flex items-center gap-2 rounded-md border border-border/60 bg-muted/30 p-2">
|
||||
<code className="min-w-0 flex-1 overflow-x-auto whitespace-nowrap text-xs">
|
||||
{createdToken?.token}
|
||||
</code>
|
||||
<Button variant="outline" size="sm" onClick={copyCreatedToken}>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={copyCreatedToken}
|
||||
className="border-0 bg-muted/30 hover:bg-muted/50"
|
||||
>
|
||||
{copiedToken ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
|
||||
</Button>
|
||||
</div>
|
||||
|
|
@ -192,6 +238,41 @@ export function ApiKeyContent() {
|
|||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<AlertDialog
|
||||
open={deleteTarget !== null}
|
||||
onOpenChange={(open) => !open && setDeleteTarget(null)}
|
||||
>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete API key?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
<span className="font-medium text-foreground">{deleteTarget?.label}</span> will be
|
||||
permanently removed. This cannot be undone.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={isMutating}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
disabled={isMutating}
|
||||
className="bg-destructive text-white hover:bg-destructive/90"
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
void handleConfirmDelete();
|
||||
}}
|
||||
>
|
||||
{isMutating ? (
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<Spinner size="xs" />
|
||||
Deleting...
|
||||
</span>
|
||||
) : (
|
||||
"Delete"
|
||||
)}
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,13 +38,13 @@ export function CommunityPromptsContent() {
|
|||
const list = prompts ?? [];
|
||||
|
||||
return (
|
||||
<div className="space-y-6 min-w-0 overflow-hidden">
|
||||
<div className="space-y-6 min-w-0">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Prompts shared by other users. Add any to your collection with one click.
|
||||
</p>
|
||||
|
||||
{isLoading && (
|
||||
<div className="space-y-2">
|
||||
<div className="-m-1 space-y-2 p-1">
|
||||
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
|
||||
<Card key={key} className="border-accent bg-accent/20">
|
||||
<CardContent className="p-4 flex flex-col gap-3 min-h-24">
|
||||
|
|
@ -76,7 +76,7 @@ export function CommunityPromptsContent() {
|
|||
)}
|
||||
|
||||
{!isLoading && !isError && list.length > 0 && (
|
||||
<div className="space-y-2">
|
||||
<div className="-m-1 space-y-2 p-1">
|
||||
{list.map((prompt) => (
|
||||
<Card
|
||||
key={prompt.id}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import { Separator } from "@/components/ui/separator";
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import type { SearchSpace } from "@/contracts/types/search-space.types";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { authenticatedFetch } from "@/lib/auth-fetch";
|
||||
import { buildBackendUrl } from "@/lib/env-config";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ export function PromptsContent() {
|
|||
const list = prompts ?? [];
|
||||
|
||||
return (
|
||||
<div className="space-y-6 min-w-0 overflow-hidden">
|
||||
<div className="space-y-6 min-w-0">
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Create prompt templates triggered with <ShortcutKbd keys={["/"]} className="ml-0" /> in
|
||||
|
|
@ -276,7 +276,7 @@ export function PromptsContent() {
|
|||
</Dialog>
|
||||
|
||||
{isLoading && (
|
||||
<div className="space-y-2">
|
||||
<div className="-m-1 space-y-2 p-1">
|
||||
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
|
||||
<Card key={key} className="border-accent bg-accent/20">
|
||||
<CardContent className="p-4 flex flex-col gap-3 min-h-24">
|
||||
|
|
@ -308,7 +308,7 @@ export function PromptsContent() {
|
|||
)}
|
||||
|
||||
{!isLoading && !isError && list.length > 0 && (
|
||||
<div className="space-y-2">
|
||||
<div className="-m-1 space-y-2 p-1">
|
||||
{list.map((prompt) => (
|
||||
<div
|
||||
key={prompt.id}
|
||||
|
|
|
|||
|
|
@ -3,31 +3,29 @@
|
|||
import { useEffect, useState } from "react";
|
||||
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
import { useSession } from "@/hooks/use-session";
|
||||
import { redirectToLogin } from "@/lib/auth-utils";
|
||||
import { queryClient } from "@/lib/query-client/client";
|
||||
|
||||
export function DashboardShell({ children }: { children: React.ReactNode }) {
|
||||
const [isCheckingAuth, setIsCheckingAuth] = useState(true);
|
||||
const session = useSession();
|
||||
|
||||
// Use the global loading screen - spinner animation won't reset
|
||||
useGlobalLoadingEffect(isCheckingAuth);
|
||||
|
||||
useEffect(() => {
|
||||
async function checkAuth() {
|
||||
let token = getBearerToken();
|
||||
if (!token) {
|
||||
const synced = await ensureTokensFromElectron();
|
||||
if (synced) token = getBearerToken();
|
||||
}
|
||||
if (!token) {
|
||||
if (session.status === "loading") return;
|
||||
if (session.status === "unauthenticated") {
|
||||
redirectToLogin();
|
||||
return;
|
||||
}
|
||||
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
|
||||
setIsCheckingAuth(false);
|
||||
}
|
||||
checkAuth();
|
||||
}, []);
|
||||
void checkAuth();
|
||||
}, [session.status]);
|
||||
|
||||
// Return null while loading - the global provider handles the loading UI
|
||||
if (isCheckingAuth) {
|
||||
|
|
|
|||
|
|
@ -9,13 +9,7 @@ import { useEffect, useState } from "react";
|
|||
import { searchSpacesAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { CreateSearchSpaceDialog } from "@/components/layout";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Card, CardDescription, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
|
||||
function ErrorScreen({ message }: { message: string }) {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Crop, Eye, EyeOff, Rocket, RotateCcw, Zap } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
|
||||
import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder";
|
||||
import { useIsGoogleAuth } from "@/components/providers/runtime-config";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
@ -17,8 +15,7 @@ import { ShortcutKbd } from "@/components/ui/shortcut-kbd";
|
|||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { setBearerToken } from "@/lib/auth-utils";
|
||||
import { buildBackendUrl } from "@/lib/env-config";
|
||||
import { getPostLoginRedirectPath } from "@/lib/auth-utils";
|
||||
|
||||
type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist";
|
||||
type ShortcutMap = typeof DEFAULT_SHORTCUTS;
|
||||
|
|
@ -190,12 +187,12 @@ export default function DesktopLoginPage() {
|
|||
const router = useRouter();
|
||||
const api = useElectronAPI();
|
||||
const isGoogleAuth = useIsGoogleAuth();
|
||||
const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom);
|
||||
|
||||
const [email, setEmail] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
const [loginError, setLoginError] = useState<string | null>(null);
|
||||
const [isLoggingIn, setIsLoggingIn] = useState(false);
|
||||
const [isGoogleRedirecting, setIsGoogleRedirecting] = useState(false);
|
||||
|
||||
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
|
||||
|
|
@ -237,10 +234,17 @@ export default function DesktopLoginPage() {
|
|||
[updateShortcut]
|
||||
);
|
||||
|
||||
const handleGoogleLogin = () => {
|
||||
const handleGoogleLogin = async () => {
|
||||
if (isGoogleRedirecting) return;
|
||||
setIsGoogleRedirecting(true);
|
||||
window.location.href = buildBackendUrl("/auth/google/authorize-redirect");
|
||||
try {
|
||||
await api?.startGoogleOAuth?.();
|
||||
await autoSetSearchSpace();
|
||||
router.push(getPostLoginRedirectPath());
|
||||
} catch (error) {
|
||||
setIsGoogleRedirecting(false);
|
||||
toast.error(error instanceof Error ? error.message : "Google sign-in failed");
|
||||
}
|
||||
};
|
||||
|
||||
const autoSetSearchSpace = async () => {
|
||||
|
|
@ -259,23 +263,19 @@ export default function DesktopLoginPage() {
|
|||
const handleLocalLogin = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setLoginError(null);
|
||||
if (isLoggingIn) return;
|
||||
setIsLoggingIn(true);
|
||||
|
||||
try {
|
||||
const data = await login({
|
||||
username: email,
|
||||
password,
|
||||
grant_type: "password",
|
||||
});
|
||||
|
||||
if (typeof window !== "undefined") {
|
||||
sessionStorage.setItem("login_success_tracked", "true");
|
||||
if (!api?.loginPassword) {
|
||||
throw new Error("Desktop password login is not available");
|
||||
}
|
||||
await api.loginPassword(email, password);
|
||||
|
||||
setBearerToken(data.access_token);
|
||||
await autoSetSearchSpace();
|
||||
|
||||
setTimeout(() => {
|
||||
router.push(`/auth/callback?token=${data.access_token}`);
|
||||
router.push(getPostLoginRedirectPath());
|
||||
}, 300);
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
|
|
@ -283,6 +283,8 @@ export default function DesktopLoginPage() {
|
|||
} else {
|
||||
setLoginError("Login failed. Please check your credentials.");
|
||||
}
|
||||
} finally {
|
||||
setIsLoggingIn(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -30,8 +30,9 @@ import {
|
|||
} from "@/components/ui/card";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type { AcceptInviteResponse } from "@/contracts/types/invites.types";
|
||||
import { useSession } from "@/hooks/use-session";
|
||||
import { invitesApiService } from "@/lib/apis/invites-api.service";
|
||||
import { getBearerToken, setRedirectPath } from "@/lib/auth-utils";
|
||||
import { setRedirectPath } from "@/lib/auth-utils";
|
||||
import {
|
||||
trackSearchSpaceInviteAccepted,
|
||||
trackSearchSpaceInviteDeclined,
|
||||
|
|
@ -43,6 +44,7 @@ export default function InviteAcceptPage() {
|
|||
const params = useParams();
|
||||
const router = useRouter();
|
||||
const inviteCode = params.invite_code as string;
|
||||
const session = useSession();
|
||||
|
||||
const { data: inviteInfo = null, isLoading: loading } = useQuery({
|
||||
queryKey: cacheKeys.invites.info(inviteCode),
|
||||
|
|
@ -81,11 +83,9 @@ export default function InviteAcceptPage() {
|
|||
|
||||
// Check if user is logged in
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined") {
|
||||
const token = getBearerToken();
|
||||
setIsLoggedIn(!!token);
|
||||
}
|
||||
}, []);
|
||||
if (session.status === "loading") return;
|
||||
setIsLoggedIn(session.status === "authenticated");
|
||||
}, [session.status]);
|
||||
|
||||
const handleAccept = async () => {
|
||||
setAccepting(true);
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { Roboto } from "next/font/google";
|
|||
import Script from "next/script";
|
||||
import { AnnouncementToastProvider } from "@/components/announcements/AnnouncementToastProvider";
|
||||
import { DesktopUpdateToast } from "@/components/desktop/desktop-update-toast";
|
||||
import { AuthCutoverPurge } from "@/components/providers/AuthCutoverPurge";
|
||||
import { GlobalLoadingProvider } from "@/components/providers/GlobalLoadingProvider";
|
||||
import { I18nProvider } from "@/components/providers/I18nProvider";
|
||||
import { PostHogProvider } from "@/components/providers/PostHogProvider";
|
||||
|
|
@ -17,13 +18,10 @@ import {
|
|||
import { ThemeProvider } from "@/components/theme/theme-provider";
|
||||
import { Toaster } from "@/components/ui/sonner";
|
||||
import { LocaleProvider } from "@/contexts/LocaleContext";
|
||||
import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config";
|
||||
import { PlatformProvider } from "@/contexts/platform-context";
|
||||
import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config";
|
||||
import { ReactQueryClientProvider } from "@/lib/query-client/query-client.provider";
|
||||
import {
|
||||
getRuntimeAuthInitScript,
|
||||
resolveRuntimeAuthUiMode,
|
||||
} from "@/lib/runtime-auth-config";
|
||||
import { getRuntimeAuthInitScript, resolveRuntimeAuthUiMode } from "@/lib/runtime-auth-config";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const roboto = Roboto({
|
||||
|
|
@ -164,6 +162,7 @@ export default function RootLayout({
|
|||
<PlatformProvider>
|
||||
<RootProvider>
|
||||
<ReactQueryClientProvider>
|
||||
<AuthCutoverPurge />
|
||||
<ZeroProvider>
|
||||
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
|
||||
</ZeroProvider>
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ export async function GET(request: NextRequest) {
|
|||
headers: {
|
||||
Authorization: request.headers.get("authorization") || "",
|
||||
"X-API-Key": request.headers.get("x-api-key") || "",
|
||||
Cookie: request.headers.get("cookie") || "",
|
||||
},
|
||||
cache: "no-store",
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { atomWithQuery } from "jotai-tanstack-query";
|
||||
import { agentFlagsApiService } from "@/lib/apis/agent-flags-api.service";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { isAuthenticated } from "@/lib/auth-utils";
|
||||
|
||||
export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const;
|
||||
|
||||
|
|
@ -12,6 +12,6 @@ export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const;
|
|||
export const agentFlagsAtom = atomWithQuery(() => ({
|
||||
queryKey: AGENT_FLAGS_QUERY_KEY,
|
||||
staleTime: 10 * 60 * 1000,
|
||||
enabled: !!getBearerToken(),
|
||||
enabled: isAuthenticated(),
|
||||
queryFn: () => agentFlagsApiService.get(),
|
||||
}));
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue