diff --git a/surfsense_web/app/api/zero/query/route.ts b/surfsense_web/app/api/zero/query/route.ts index baee8bd75..98cd4551c 100644 --- a/surfsense_web/app/api/zero/query/route.ts +++ b/surfsense_web/app/api/zero/query/route.ts @@ -1,17 +1,49 @@ import { mustGetQuery } from "@rocicorp/zero"; import { handleQueryRequest } from "@rocicorp/zero/server"; import { NextResponse } from "next/server"; +import type { Context } from "@/types/zero"; import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; +const backendURL = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + +async function authenticateRequest( + request: Request +): Promise<{ ctx: Context; error?: never } | { ctx?: never; error: NextResponse }> { + const authHeader = request.headers.get("Authorization"); + if (!authHeader?.startsWith("Bearer ")) { + return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) }; + } + + try { + const res = await fetch(`${backendURL}/users/me`, { + headers: { Authorization: authHeader }, + }); + + if (!res.ok) { + return { error: NextResponse.json({ error: "Unauthorized" }, { status: 401 }) }; + } + + const user = await res.json(); + return { ctx: { userId: String(user.id) } }; + } catch { + return { error: NextResponse.json({ error: "Auth service unavailable" }, { status: 503 }) }; + } +} + export async function POST(request: Request) { + const auth = await authenticateRequest(request); + if (auth.error) { + return auth.error; + } + const result = await handleQueryRequest( (name, args) => { const query = mustGetQuery(queries, name); - return query.fn({ args, ctx: undefined }); + return query.fn({ args, ctx: auth.ctx }); }, schema, - request, + request ); return NextResponse.json(result); diff --git a/surfsense_web/components/providers/ZeroProvider.tsx b/surfsense_web/components/providers/ZeroProvider.tsx index 1a0a2f937..3bd579ea9 100644 --- a/surfsense_web/components/providers/ZeroProvider.tsx +++ b/surfsense_web/components/providers/ZeroProvider.tsx @@ -1,26 +1,60 @@ "use client"; +import { + useConnectionState, + useZero, + ZeroProvider as ZeroReactProvider, +} from "@rocicorp/zero/react"; +import { useAtomValue } from "jotai"; +import { useEffect, useRef } from "react"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { getBearerToken, handleUnauthorized, refreshAccessToken } from "@/lib/auth-utils"; import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; -import { ZeroProvider as ZeroReactProvider } from "@rocicorp/zero/react"; -import { useAtomValue } from "jotai"; const cacheURL = process.env.NEXT_PUBLIC_ZERO_CACHE_URL || "http://localhost:4848"; +function ZeroAuthGuard({ children }: { children: React.ReactNode }) { + const zero = useZero(); + const connectionState = useConnectionState(); + const isRefreshingRef = useRef(false); + + useEffect(() => { + if (connectionState.name !== "needs-auth" || isRefreshingRef.current) return; + + isRefreshingRef.current = true; + + refreshAccessToken() + .then((newToken) => { + if (newToken) { + zero.connection.connect({ auth: newToken }); + } else { + handleUnauthorized(); + } + }) + .finally(() => { + isRefreshingRef.current = false; + }); + }, [connectionState, zero]); + + return <>{children}; +} + export function ZeroProvider({ children }: { children: React.ReactNode }) { const { data: user } = useAtomValue(currentUserAtom); + const token = getBearerToken(); - if (!user?.id) { + if (!user?.id || !token) { return <>{children}; } const userID = String(user.id); const context = { userId: userID }; + const auth = token; return ( - - {children} + + {children} ); }