diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index 7336fa9bd..4fcc8c597 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -95,12 +95,10 @@ jobs: run: pnpm build working-directory: surfsense_web env: - NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }} - SURFSENSE_BACKEND_INTERNAL_URL: ${{ vars.HOSTED_BACKEND_URL }} + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }} NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }} NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }} - NEXT_PUBLIC_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_AUTH_TYPE }} - NEXT_PUBLIC_ETL_SERVICE: ${{ vars.NEXT_PUBLIC_ETL_SERVICE }} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }} NEXT_PUBLIC_POSTHOG_KEY: ${{ secrets.NEXT_PUBLIC_POSTHOG_KEY }} - name: Install desktop dependencies @@ -111,7 +109,6 @@ jobs: run: pnpm build working-directory: surfsense_desktop env: - HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }} HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }} POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }} POSTHOG_HOST: ${{ vars.POSTHOG_HOST }} diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index f82c4d609..8da56fc33 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -199,6 +199,11 @@ jobs: build-args: | ${{ matrix.image == 'backend' && format('USE_CUDA={0}', matrix.use_cuda) || '' }} ${{ matrix.image == 'backend' && format('CUDA_EXTRA={0}', matrix.cuda_extra) || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }} - name: Export digest run: | diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index bd9cff13b..b87537dab 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -27,10 +27,9 @@ jobs: PLAYWRIGHT_TEST_EMAIL: e2e-test@surfsense.net PLAYWRIGHT_TEST_PASSWORD: E2eTestPassword123! # Frontend env: Playwright's webServer (surfsense_web/playwright.config.ts) - # spawns `pnpm build && pnpm start` in CI. + # spawns `pnpm build && pnpm start` in CI; these get baked into the build. NEXT_PUBLIC_FASTAPI_BACKEND_URL: http://localhost:8000 - SURFSENSE_BACKEND_INTERNAL_URL: http://localhost:8000 - AUTH_TYPE: LOCAL + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: LOCAL # Shared secret for the test-only POST /__e2e__/auth/token endpoint. # Must match docker-compose.e2e.yml's backend env (x-backend-env). E2E_MINT_SECRET: e2e-mint-secret-not-for-production diff --git a/VERSION b/VERSION index 369bd4c2a..24ff85581 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.29 +0.0.27 diff --git a/docker/.env.example b/docker/.env.example index 63308bc9e..cafc74af9 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -30,9 +30,6 @@ SECRET_KEY=replace_me_with_a_random_string # Auth type: LOCAL (email/password) or GOOGLE (OAuth) AUTH_TYPE=LOCAL -# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them. -DEPLOYMENT_MODE=self-hosted - # Allow new user registrations (TRUE or FALSE) # REGISTRATION_ENABLED=TRUE @@ -46,47 +43,51 @@ ETL_SERVICE=DOCLING EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 # ------------------------------------------------------------------------------ -# How You Access SurfSense +# Ports (change to avoid conflicts with other services on your machine) # ------------------------------------------------------------------------------ -# One public URL. Browser traffic stays same-origin and Caddy routes internally. -SURFSENSE_PUBLIC_URL=http://localhost:3929 + +# BACKEND_PORT=8929 +# FRONTEND_PORT=3929 +# ZERO_CACHE_PORT=5929 +# SEARXNG_PORT=8888 +# FLOWER_PORT=5555 + +# ============================================================================== +# DEV COMPOSE ONLY (docker-compose.dev.yml) +# You only need them only if you are running `docker-compose.dev.yml`. +# ============================================================================== + +# -- pgAdmin (database GUI) -- +# PGADMIN_PORT=5050 +# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com +# PGADMIN_DEFAULT_PASSWORD=surfsense + +# -- Redis exposed port (dev only; Redis is internal-only in prod) -- +# REDIS_PORT=6379 + +# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) -- +# WHATSAPP_BRIDGE_PORT=9929 + +# -- Frontend Build Args -- +# In dev, the frontend is built from source and these are passed as build args. +# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above. +# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL +# NEXT_PUBLIC_ETL_SERVICE=DOCLING +# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted # ------------------------------------------------------------------------------ -# Public Ports +# Custom Domain / Reverse Proxy # ------------------------------------------------------------------------------ -# Production Docker exposes only Caddy to your machine. Caddy then routes -# frontend, backend, and zero-cache traffic internally. +# ONLY set these if you are serving SurfSense on a real domain via a reverse +# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel). +# For standard localhost deployments, leave all of these commented out. +# they are automatically derived from the port settings above. # -# Local default: LISTEN_HTTP_PORT=3929 -# Domain default: LISTEN_HTTP_PORT=80 and LISTEN_HTTPS_PORT=443 -LISTEN_HTTP_PORT=3929 -LISTEN_HTTPS_PORT=443 - -# ------------------------------------------------------------------------------ -# Custom Domain / HTTPS -# ------------------------------------------------------------------------------ -# Leave SURFSENSE_SITE_ADDRESS as :80 for local HTTP. -# Set it to your domain to enable automatic HTTPS: -# SURFSENSE_SITE_ADDRESS=surf.example.com -# CERT_EMAIL=you@example.com -SURFSENSE_SITE_ADDRESS=:80 -CERT_EMAIL= - -# ------------------------------------------------------------------------------ -# Advanced Reverse Proxy Settings -# ------------------------------------------------------------------------------ -# Usually do not change these. They are for custom certificate setup, CDNs/load -# balancers, trusted proxy IPs, or changing upload limits. -# -# CERT_ACME_CA=https://acme-v02.api.letsencrypt.org/directory -# CERT_ACME_DNS= -# If a CDN/load balancer sits in front of Caddy, narrow this to that proxy's CIDRs. -# TRUSTED_PROXIES=0.0.0.0/0 -# SURFSENSE_MAX_BODY_SIZE=5GB -# -# Browser API and Zero URLs are same-origin relative behind bundled Caddy. -# Next.js server-side calls use Docker DNS through SURFSENSE_BACKEND_INTERNAL_URL -# set internally by docker-compose.yml. Usually do not override it. +# NEXT_FRONTEND_URL=https://app.yourdomain.com +# BACKEND_URL=https://api.yourdomain.com +# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com +# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com +# FASTAPI_BACKEND_INTERNAL_URL=http://backend:8000 # ------------------------------------------------------------------------------ # Zero-cache (real-time sync) @@ -107,9 +108,10 @@ CERT_EMAIL= # Sync worker tuning. zero-cache defaults ZERO_NUM_SYNC_WORKERS to the number # of CPU cores, which can exceed the connection pool limits on high-core machines. -# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR pools. -# Keep ZERO_UPSTREAM_MAX_CONNS and ZERO_CVR_MAX_CONNS greater than or equal to -# ZERO_NUM_SYNC_WORKERS. +# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR +# pools, so these constraints must hold: +# ZERO_UPSTREAM_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS +# ZERO_CVR_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS # Default of 4 workers is sufficient for self-hosted / personal use. # ZERO_NUM_SYNC_WORKERS=4 # ZERO_UPSTREAM_MAX_CONNS=20 @@ -123,16 +125,16 @@ CERT_EMAIL= # ZERO_QUERY_URL: where zero-cache forwards query requests for resolution. # ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though -# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to -# skip its own JWT verification and let the app endpoints handle auth instead. -# The mutate endpoint is a no-op that returns an empty response. +# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to +# skip its own JWT verification and let the app endpoints handle auth instead. +# The mutate endpoint is a no-op that returns an empty response. # Default: Docker service networking (http://frontend:3000/api/zero/...). # Override when running the frontend outside Docker: -# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query -# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate -# Override for custom domain only when zero-cache is not in the bundled Docker network: -# ZERO_QUERY_URL=https://surf.example.com/api/zero/query -# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate +# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query +# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate +# Override for custom domain: +# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query +# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate # ZERO_QUERY_URL=http://frontend:3000/api/zero/query # ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate @@ -164,26 +166,25 @@ CERT_EMAIL= # REDIS_URL=redis://redis:6379/0 # ------------------------------------------------------------------------------ -# Stripe (unified credit wallet, disabled by default) +# Stripe (pay-as-you-go page packs, disabled by default) # ------------------------------------------------------------------------------ -# Set TRUE to allow users to buy credit packs via Stripe Checkout. $1 buys -# 1_000_000 micro-USD of credit; both ETL page processing and premium turns -# debit this balance at the actual per-call provider cost from LiteLLM. -STRIPE_CREDIT_BUYING_ENABLED=FALSE +# Set TRUE to allow users to buy additional page packs via Stripe Checkout +STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_SECRET_KEY=sk_test_... # STRIPE_WEBHOOK_SECRET=whsec_... -# STRIPE_CREDIT_PRICE_ID=price_... -# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# STRIPE_PRICE_ID=price_... +# STRIPE_PAGES_PER_UNIT=1000 # STRIPE_RECONCILIATION_INTERVAL=10m # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Auto-reload: top up via a saved Stripe card when the balance drops below -# the user-chosen threshold. Off by default. -# AUTO_RELOAD_ENABLED=FALSE -# AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000 -# AUTO_RELOAD_COOLDOWN_MINUTES=10 +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) +# STRIPE_TOKEN_BUYING_ENABLED=FALSE +# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -220,74 +221,73 @@ STT_SERVICE=local/base # ------------------------------------------------------------------------------ # -- Google Connectors -- -# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/calendar/connector/callback -# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/gmail/connector/callback -# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/drive/connector/callback +# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback +# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback +# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback # -- Notion -- # NOTION_CLIENT_ID= # NOTION_CLIENT_SECRET= -# NOTION_REDIRECT_URI=http://localhost:3929/api/v1/auth/notion/connector/callback +# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback # -- Slack -- # SLACK_CLIENT_ID= # SLACK_CLIENT_SECRET= -# SLACK_REDIRECT_URI=http://localhost:3929/api/v1/auth/slack/connector/callback +# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback # -- Discord -- # DISCORD_CLIENT_ID= # DISCORD_CLIENT_SECRET= -# DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/auth/discord/connector/callback +# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback # DISCORD_BOT_TOKEN= # -- Atlassian (Jira & Confluence) -- # ATLASSIAN_CLIENT_ID= # ATLASSIAN_CLIENT_SECRET= -# JIRA_REDIRECT_URI=http://localhost:3929/api/v1/auth/jira/connector/callback -# CONFLUENCE_REDIRECT_URI=http://localhost:3929/api/v1/auth/confluence/connector/callback +# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback +# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback # -- Linear -- # LINEAR_CLIENT_ID= # LINEAR_CLIENT_SECRET= -# LINEAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/linear/connector/callback +# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback # -- ClickUp -- # CLICKUP_CLIENT_ID= # CLICKUP_CLIENT_SECRET= -# CLICKUP_REDIRECT_URI=http://localhost:3929/api/v1/auth/clickup/connector/callback +# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback # -- Airtable -- # AIRTABLE_CLIENT_ID= # AIRTABLE_CLIENT_SECRET= -# AIRTABLE_REDIRECT_URI=http://localhost:3929/api/v1/auth/airtable/connector/callback +# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback # -- Microsoft OAuth (Teams & OneDrive) -- # MICROSOFT_CLIENT_ID= # MICROSOFT_CLIENT_SECRET= -# TEAMS_REDIRECT_URI=http://localhost:3929/api/v1/auth/teams/connector/callback -# ONEDRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/onedrive/connector/callback +# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback +# ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback # -- Dropbox -- # DROPBOX_APP_KEY= # DROPBOX_APP_SECRET= -# DROPBOX_REDIRECT_URI=http://localhost:3929/api/v1/auth/dropbox/connector/callback +# DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback # -- Composio -- # COMPOSIO_API_KEY= # COMPOSIO_ENABLED=TRUE -# COMPOSIO_REDIRECT_URI=http://localhost:3929/api/v1/auth/composio/connector/callback +# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback # ------------------------------------------------------------------------------ # Messaging Channels (optional) # ------------------------------------------------------------------------------ # Configure only the external chat channels you want to use. -# GATEWAY_ENABLED=TRUE # -- Telegram -- # TELEGRAM_SHARED_BOT_TOKEN= # TELEGRAM_SHARED_BOT_USERNAME= # TELEGRAM_WEBHOOK_SECRET= -# GATEWAY_BASE_URL=http://localhost:3929 +# GATEWAY_BASE_URL=http://localhost:8929 # GATEWAY_TELEGRAM_INTAKE_MODE=webhook # -- WhatsApp -- @@ -306,20 +306,20 @@ STT_SERVICE=local/base # # GATEWAY_SLACK_ENABLED=FALSE # GATEWAY_SLACK_SIGNING_SECRET= -# GATEWAY_SLACK_REDIRECT_URI=http://localhost:3929/api/v1/gateway/slack/callback +# GATEWAY_SLACK_REDIRECT_URI=http://localhost:8929/api/v1/gateway/slack/callback # -- Discord -- # Uses DISCORD_CLIENT_ID, DISCORD_CLIENT_SECRET, and DISCORD_BOT_TOKEN from the # Discord connector section. # # GATEWAY_DISCORD_ENABLED=FALSE -# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/gateway/discord/callback +# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:8929/api/v1/gateway/discord/callback # ------------------------------------------------------------------------------ # SearXNG (bundled web search, works out of the box with no config needed) # ------------------------------------------------------------------------------ # SearXNG provides web search to all search spaces automatically. -# To access the SearXNG UI directly in dev/deps-only compose: http://localhost:8888 +# To access the SearXNG UI directly: http://localhost:8888 # To disable the service entirely: docker compose up --scale searxng=0 # To point at your own SearXNG instance instead of the bundled one: # SEARXNG_DEFAULT_HOST=http://your-searxng:8080 @@ -407,16 +407,13 @@ SURFSENSE_ENABLE_DOOM_LOOP=true # ACCESS_TOKEN_LIFETIME_SECONDS=86400 # REFRESH_TOKEN_LIFETIME_SECONDS=1209600 -# Unified credit wallet starting balance for new users, in micro-USD -# (default: $5). Funds both ETL page processing and premium model calls, -# debited at the actual per-call provider cost reported by LiteLLM. -# DEFAULT_CREDIT_MICROS_BALANCE=5000000 +# Pages limit per user for ETL (default: unlimited) +# PAGES_LIMIT=500 -# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL -# effectively free for self-hosted installs. 1 page == MICROS_PER_PAGE -# micro-USD ($0.001); premium ETL mode is 10x. -# ETL_CREDIT_BILLING_ENABLED=FALSE -# MICROS_PER_PAGE=1000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 # Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). # QUOTA_MAX_RESERVE_MICROS=1000000 @@ -456,36 +453,3 @@ NOLOGIN_MODE_ENABLED=FALSE # RESIDENTIAL_PROXY_HOSTNAME= # RESIDENTIAL_PROXY_LOCATION= # RESIDENTIAL_PROXY_TYPE=1 - -# ============================================================================== -# DEV / DEPS-ONLY COMPOSE OVERRIDES -# These are only needed for docker-compose.dev.yml or docker-compose.deps-only.yml. -# Production Docker exposes Caddy only; raw app ports below do not affect -# docker-compose.yml. -# ============================================================================== - -# -- pgAdmin (database GUI, dev/deps-only only) -- -# PGADMIN_PORT=5050 -# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com -# PGADMIN_DEFAULT_PASSWORD=surfsense - -# -- Redis exposed port (dev/deps-only only; Redis is internal-only in prod) -- -# REDIS_PORT=6379 - -# -- SearXNG exposed port (dev/deps-only only; internal-only in prod) -- -# SEARXNG_PORT=8888 - -# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) -- -# WHATSAPP_BRIDGE_PORT=9929 - -# -- Raw app ports (dev/deps-only only; prod exposes Caddy instead) -- -# BACKEND_PORT=8000 -# FRONTEND_PORT=3000 -# ZERO_CACHE_PORT=4848 - -# -- Frontend runtime flags (prod and dev compose) -- -# The frontend reads these at request time in Docker; no NEXT_PUBLIC_* rebuild -# or startup substitution is required. -# AUTH_TYPE=LOCAL -# ETL_SERVICE=DOCLING -# DEPLOYMENT_MODE=self-hosted diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 5b86ea888..35effefc0 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -106,7 +106,6 @@ services: volumes: - ../surfsense_backend/app:/app/app - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - ../surfsense_backend/.env extra_hosts: @@ -120,7 +119,6 @@ services: - PYTHONPATH=/app - UVICORN_LOOP=asyncio - UNSTRUCTURED_HAS_PATCHED_LOOP=1 - - FILE_STORAGE_LOCAL_PATH=/app/.local_object_store - LANGCHAIN_TRACING_V2=false - LANGSMITH_TRACING=false - AUTH_TYPE=${AUTH_TYPE:-LOCAL} @@ -173,7 +171,6 @@ services: volumes: - ../surfsense_backend/app:/app/app - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - ../surfsense_backend/.env extra_hosts: @@ -185,7 +182,6 @@ services: - REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0} - CELERY_TASK_DEFAULT_QUEUE=surfsense - PYTHONPATH=/app - - FILE_STORAGE_LOCAL_PATH=/app/.local_object_store - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080} - SERVICE_ROLE=worker depends_on: @@ -257,15 +253,16 @@ services: frontend: build: context: ../surfsense_web + args: + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL} + NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING} + NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}} + NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted} ports: - "${FRONTEND_PORT:-3000}:3000" env_file: - ../surfsense_web/.env - environment: - AUTH_TYPE: ${AUTH_TYPE:-LOCAL} - ETL_SERVICE: ${ETL_SERVICE:-DOCLING} - DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} - SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000 depends_on: backend: condition: service_healthy @@ -281,8 +278,6 @@ volumes: name: surfsense-dev-redis shared_temp: name: surfsense-dev-shared-temp - object_store: - name: surfsense-dev-object-store zero_cache_data: name: surfsense-dev-zero-cache whatsapp_sessions: diff --git a/docker/docker-compose.proxy.yml b/docker/docker-compose.proxy.yml deleted file mode 100644 index 1990f6db8..000000000 --- a/docker/docker-compose.proxy.yml +++ /dev/null @@ -1,54 +0,0 @@ -# ============================================================================= -# SurfSense — Optional Caddy reverse-proxy overlay -# ============================================================================= -# Usage (from docker/): -# PROXY_HTTP_PORT=8080 SURFSENSE_PUBLIC_URL=http://localhost:8080 \ -# docker compose -f docker-compose.yml -f docker-compose.proxy.yml up -d -# -# This overlay is for validation and custom deployments. The production -# docker-compose.yml includes Caddy by default. -# ============================================================================= - -services: - backend: - ports: - - "${BACKEND_PORT:-8929}:8000" - - zero-cache: - ports: - - "${ZERO_CACHE_PORT:-5929}:4848" - - frontend: - ports: - - "${FRONTEND_PORT:-3929}:3000" - - proxy: - image: caddy:2-alpine - restart: unless-stopped - ports: - - "${PROXY_HTTP_PORT:-8080}:80" - - "${PROXY_HTTPS_PORT:-8443}:443" - volumes: - - ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro - - caddy_data:/data - - caddy_config:/config - environment: - SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80} - CERT_EMAIL: ${CERT_EMAIL:-} - CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory} - CERT_ACME_DNS: ${CERT_ACME_DNS:-} - TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0} - SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB} - depends_on: - frontend: - condition: service_started - backend: - condition: service_healthy - zero-cache: - condition: service_healthy - -volumes: - caddy_data: - name: surfsense-caddy-data - caddy_config: - name: surfsense-caddy-config diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1ee7ae0ed..9bbf28ffd 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -94,42 +94,12 @@ services: timeout: 5s retries: 5 - # Single public entry point for the Docker stack. Comment this service out - # only if you front SurfSense with your own reverse proxy. - proxy: - image: caddy:2-alpine - # For DNS-01/wildcard certificates, replace image with: - # build: ./proxy - restart: unless-stopped - ports: - - "${LISTEN_HTTP_PORT:-3929}:80" - - "${LISTEN_HTTPS_PORT:-443}:443" - volumes: - - ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro - - caddy_data:/data - - caddy_config:/config - environment: - SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80} - CERT_EMAIL: ${CERT_EMAIL:-} - CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory} - CERT_ACME_DNS: ${CERT_ACME_DNS:-} - TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0} - SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB} - depends_on: - frontend: - condition: service_started - backend: - condition: service_healthy - zero-cache: - condition: service_healthy - backend: image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}} - expose: - - "8000" + ports: + - "${BACKEND_PORT:-8929}:8000" volumes: - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - .env extra_hosts: @@ -143,9 +113,7 @@ services: PYTHONPATH: /app UVICORN_LOOP: asyncio UNSTRUCTURED_HAS_PATCHED_LOOP: "1" - FILE_STORAGE_LOCAL_PATH: /app/.local_object_store - NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}} - BACKEND_URL: ${BACKEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}} + NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}} SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} WHATSAPP_BRIDGE_URL: ${WHATSAPP_BRIDGE_URL:-http://whatsapp-bridge:9929} # Daytona Sandbox – uncomment and set credentials to enable cloud code execution @@ -197,7 +165,6 @@ services: image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}} volumes: - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - .env extra_hosts: @@ -209,7 +176,6 @@ services: REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0} CELERY_TASK_DEFAULT_QUEUE: surfsense PYTHONPATH: /app - FILE_STORAGE_LOCAL_PATH: /app/.local_object_store SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} SERVICE_ROLE: worker depends_on: @@ -251,8 +217,8 @@ services: zero-cache: image: rocicorp/zero:1.4.0 - expose: - - "4848" + ports: + - "${ZERO_CACHE_PORT:-5929}:4848" extra_hosts: - "host.docker.internal:host-gateway" environment: @@ -286,13 +252,16 @@ services: frontend: image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest} - expose: - - "3000" + ports: + - "${FRONTEND_PORT:-3929}:3000" environment: - AUTH_TYPE: ${AUTH_TYPE:-LOCAL} - ETL_SERVICE: ${ETL_SERVICE:-DOCLING} - DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} - SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000 + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}} + NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL} + NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING} + NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} + NEXT_PUBLIC_WHATSAPP_DISPLAY_PHONE_NUMBER: ${WHATSAPP_SHARED_DISPLAY_PHONE_NUMBER:-} + FASTAPI_BACKEND_INTERNAL_URL: ${FASTAPI_BACKEND_INTERNAL_URL:-http://backend:8000} labels: - "com.centurylinklabs.watchtower.enable=true" depends_on: @@ -309,13 +278,7 @@ volumes: name: surfsense-redis shared_temp: name: surfsense-shared-temp - object_store: - name: surfsense-object-store zero_cache_data: name: surfsense-zero-cache - caddy_data: - name: surfsense-caddy-data - caddy_config: - name: surfsense-caddy-config whatsapp_sessions: name: surfsense-whatsapp-sessions diff --git a/docker/proxy/Caddyfile b/docker/proxy/Caddyfile deleted file mode 100644 index 534a8c2c2..000000000 --- a/docker/proxy/Caddyfile +++ /dev/null @@ -1,45 +0,0 @@ -{ - # Optional ACME/global settings. These are harmless in the default :80 - # localhost mode and become active when SURFSENSE_SITE_ADDRESS is a domain. - {$CERT_EMAIL} - acme_ca {$CERT_ACME_CA:https://acme-v02.api.letsencrypt.org/directory} - {$CERT_ACME_DNS} - servers { - client_ip_headers X-Forwarded-For X-Real-IP - trusted_proxies static {$TRUSTED_PROXIES:0.0.0.0/0} - } -} - -(surfsense_proxy) { - request_body { - max_size {$SURFSENSE_MAX_BODY_SIZE:5GB} - } - - # Frontend-owned auth page (the post-login token handler). More specific than - # /auth/*, so Caddy's matcher-specificity sort routes it here, not to backend. - reverse_proxy /auth/callback* frontend:3000 - - # Backend auth routes (FastAPI Users + OAuth helpers). - reverse_proxy /auth/* backend:8000 - - # Backend user profile routes (FastAPI Users users router, mounted at /users). - reverse_proxy /users/* backend:8000 - - # Backend REST, streaming, connector OAuth, and messaging gateway endpoints. - # FastAPI already serves /api/v1, so the path is forwarded unchanged. - reverse_proxy /api/v1/* backend:8000 { - flush_interval -1 - } - - # Zero accepts a single path-component base URL (Zero >= 0.6). - # Preserve /zero so browser cacheURL can be ${SURFSENSE_PUBLIC_URL}/zero. - reverse_proxy /zero/* zero-cache:4848 - - # Next.js app and frontend-owned API routes: - # /api/zero/*, /api/search, /api/contact, etc. - reverse_proxy /* frontend:3000 -} - -{$SURFSENSE_SITE_ADDRESS::80} { - import surfsense_proxy -} diff --git a/docker/proxy/Dockerfile b/docker/proxy/Dockerfile deleted file mode 100644 index 8395a817c..000000000 --- a/docker/proxy/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM caddy:2-builder-alpine AS builder - -RUN xcaddy build \ - --with github.com/caddy-dns/cloudflare \ - --with github.com/caddy-dns/digitalocean - -FROM caddy:2-alpine - -COPY --from=builder /usr/bin/caddy /usr/bin/caddy -COPY Caddyfile /etc/caddy/Caddyfile diff --git a/docker/scripts/install.ps1 b/docker/scripts/install.ps1 index 23b14c3c4..6e973a520 100644 --- a/docker/scripts/install.ps1 +++ b/docker/scripts/install.ps1 @@ -17,14 +17,10 @@ # into the new PostgreSQL 17 stack. The user runs one command for both. # ============================================================================= -# NOTE: Do not use [ValidateSet()] (or other validation attributes without a -# valid default) on these params. When the script is piped into iex, PowerShell -# applies the attributes to variables in the caller's scope, and an empty -# $Variant fails ValidateSet with a ValidationMetadataException. Validate -# manually below instead. param( [switch]$NoWatchtower, [int]$WatchtowerInterval = 86400, + [ValidateSet("cpu", "cuda", "cuda126")] [string]$Variant, [string]$GpuCount, [switch]$Quiet @@ -44,11 +40,6 @@ $MigrationMode = $false $SetupWatchtower = -not $NoWatchtower $WatchtowerContainer = "watchtower" -if ($Variant -and $Variant -notin @("cpu", "cuda", "cuda126")) { - Write-Host "[SurfSense] ERROR: Invalid -Variant '$Variant'. Use 'cpu', 'cuda', or 'cuda126'." -ForegroundColor Red - exit 1 -} - if ($GpuCount -and $GpuCount -notmatch '^([0-9]+|all)$') { Write-Host "[SurfSense] ERROR: Invalid -GpuCount '$GpuCount'. Use a number or 'all'." -ForegroundColor Red exit 1 diff --git a/docker/scripts/install.sh b/docker/scripts/install.sh index f9660b132..4df15fbd0 100644 --- a/docker/scripts/install.sh +++ b/docker/scripts/install.sh @@ -333,13 +333,11 @@ step "Downloading SurfSense files" info "Installation directory: ${INSTALL_DIR}" mkdir -p "${INSTALL_DIR}/scripts" mkdir -p "${INSTALL_DIR}/searxng" -mkdir -p "${INSTALL_DIR}/proxy" FILES=( "docker/docker-compose.yml:docker-compose.yml" "docker/docker-compose.gpu.yml:docker-compose.gpu.yml" "docker/.env.example:.env.example" - "docker/proxy/Caddyfile:proxy/Caddyfile" "docker/postgresql.conf:postgresql.conf" "docker/scripts/migrate-database.sh:scripts/migrate-database.sh" "docker/searxng/settings.yml:searxng/settings.yml" @@ -534,12 +532,9 @@ _variant_display=$(grep '^SURFSENSE_VARIANT=' "${INSTALL_DIR}/.env" 2>/dev/null _variant_display="${_variant_display:-cpu}" step "SurfSense is now installed [${_version_display}]" -_public_url=$(grep '^SURFSENSE_PUBLIC_URL=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2- | tr -d '"' | head -1 || true) -_public_url="${_public_url:-http://localhost:3929}" - -info " SurfSense: ${_public_url}" -info " Backend: ${_public_url}/api/v1" -info " Zero sync: ${_public_url}/zero" +info " Frontend: http://localhost:3929" +info " Backend: http://localhost:8929" +info " API Docs: http://localhost:8929/docs" info "" info " Config: ${INSTALL_DIR}/.env" info " Variant: ${_variant_display}" diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index a6b2b30a3..6e49a7132 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -1,20 +1,5 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense -# --- Database startup / safety knobs (optional) --- -# Run extension/table/index DDL on app startup. Set FALSE when schema is owned -# exclusively by Alembic migrations. -# DB_BOOTSTRAP_ON_STARTUP=TRUE -# lock_timeout (ms) for boot-time DDL so a contended CREATE INDEX/TABLE fails -# fast instead of hanging the FastAPI lifespan behind another transaction. -# DB_DDL_LOCK_TIMEOUT_MS=5000 -# idle_in_transaction_session_timeout (ms) so an abandoned "idle in transaction" -# session can't wedge the DB indefinitely. 0 disables. (asyncpg only) -# DB_IDLE_IN_TX_TIMEOUT_MS=900000 -# Same, for the Celery worker engine (long ingestion/podcast/video tasks). If a -# task hasn't touched the DB in this window it's treated as orphaned and dropped. -# 0 disables. (asyncpg only) -# DB_CELERY_IDLE_IN_TX_TIMEOUT_MS=3600000 - # Deployment environment: dev or production SURFSENSE_ENV=dev @@ -30,9 +15,12 @@ CELERY_TASK_DEFAULT_QUEUE=surfsense # Optional: TTL in seconds for connector indexing lock key # CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800 -# Messaging Gateway: disabled by default; set TRUE to enable chat integrations. -# Supported messaging gateways: WhatsApp, Telegram, Discord, Slack -# GATEWAY_ENABLED=TRUE +# Messaging Gateway (global) +# GATEWAY_ENABLED: master switch for ALL messaging gateway channels (Telegram, WhatsApp, +# Slack, Discord). When FALSE, no gateway background workers/supervisors start and all +# gateway HTTP routes (webhooks, OAuth callbacks, pairing) return 404. Set per-channel +# flags below to control individual platforms once the gateway is enabled. +GATEWAY_ENABLED=TRUE # Telegram Gateway # TELEGRAM_WEBHOOK_SECRET must be 1-256 chars and contain only A-Z, a-z, 0-9, _ or - @@ -87,16 +75,23 @@ SECRET_KEY=SECRET NEXT_FRONTEND_URL=http://localhost:3000 -# Stripe Checkout for the unified credit wallet. -# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit -# (default 1_000_000 = $1.00). Both ETL page processing and premium model -# turns are billed against this single balance at actual provider cost. +# Stripe Checkout for pay-as-you-go page packs +# Configure STRIPE_PRICE_ID to point at your 1,000-page price in Stripe. +# Pages granted per purchase = quantity * STRIPE_PAGES_PER_UNIT. STRIPE_SECRET_KEY=sk_test_... STRIPE_WEBHOOK_SECRET=whsec_... -STRIPE_CREDIT_PRICE_ID=price_... -STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +STRIPE_PRICE_ID=price_... +STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily -STRIPE_CREDIT_BUYING_ENABLED=FALSE +STRIPE_PAGE_BUYING_ENABLED=TRUE + +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. +STRIPE_TOKEN_BUYING_ENABLED=FALSE +STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -226,25 +221,15 @@ VIDEO_PRESENTATION_FPS=30 VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 -# Unified credit wallet starting balance for new users, in micro-USD -# (default: 5,000,000 == $5.00). The same balance funds ETL page processing -# and premium model calls, debited at actual provider cost. -DEFAULT_CREDIT_MICROS_BALANCE=5000000 +# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) +PAGES_LIMIT=500 -# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL -# effectively free for self-hosted/OSS installs; hosted deployments set TRUE. -# 1 page == MICROS_PER_PAGE micro-USD ($0.001); premium ETL mode is 10x. -ETL_CREDIT_BILLING_ENABLED=FALSE -MICROS_PER_PAGE=1000 - -# Low-balance warning threshold (micro-USD), surfaced to the UI. Default $0.50. -CREDIT_LOW_BALANCE_WARNING_MICROS=500000 - -# Auto-reload: automatically top up via a saved Stripe card when the balance -# drops below the user-chosen threshold. Off by default. -AUTO_RELOAD_ENABLED=FALSE -AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000 -AUTO_RELOAD_COOLDOWN_MINUTES=10 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 # Safety ceiling on per-call premium reservation, in micro-USD. # stream_new_chat estimates an upper-bound cost from the model's @@ -323,42 +308,6 @@ FILE_STORAGE_BACKEND=local # AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net # AZURE_STORAGE_CONTAINER=surfsense-documents -# ETL Parse Cache -# Reuse parser output for identical file bytes across workspaces (skips paid -# re-parsing on LlamaCloud / Azure DI / Unstructured). Off by default. -ETL_CACHE_ENABLED=false -# Bump to invalidate all cached entries after a parser/behaviour change. -# ETL_CACHE_PARSER_VERSION=1 -# Prune entries unused for this many days. -# ETL_CACHE_TTL_DAYS=90 -# Soft cap on total cached markdown; coldest entries are evicted past it. -# ETL_CACHE_MAX_TOTAL_MB=5120 -# Rows deleted per eviction pass. -# ETL_CACHE_EVICTION_BATCH=500 -# Optional dedicated blob storage; unset reuses the main file storage backend. -# ETL_CACHE_STORAGE_BACKEND=azure -# ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache -# ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache - -# Embedding Cache -# Reuse chunk+embedding output for identical markdown across workspaces (skips -# re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend. -# Off by default. -EMBEDDING_CACHE_ENABLED=false -# Bump to invalidate all cached embedding sets after a chunker change. -# EMBEDDING_CACHE_CHUNKER_VERSION=1 -# Prune entries unused for this many days. -# EMBEDDING_CACHE_TTL_DAYS=90 -# Soft cap on total cached embeddings; coldest entries are evicted past it. -# EMBEDDING_CACHE_MAX_TOTAL_MB=5120 -# Rows deleted per eviction pass. -# EMBEDDING_CACHE_EVICTION_BATCH=500 - -# Incremental re-indexing: on document edits, keep chunks whose text is -# unchanged (reusing their embeddings) and embed only new/changed ones. -# Set to false to fall back to delete-all + full re-embed (kill switch). -# CHUNK_RECONCILE_ENABLED=true - # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_API_KEY=your-daytona-api-key @@ -398,9 +347,7 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call # Observability - OTel -# Disabled by default. Uncomment to enable OpenTelemetry. -# SURFSENSE_ENABLE_OTEL=true - +# SURFSENSE_ENABLE_OTEL=false # OpenTelemetry - endpoint enables export; absent = no-op. # Production should point at an OTel Collector. For local docker-compose.dev.yml, # use http://otel-lgtm:4317 instead. diff --git a/surfsense_backend/.gitignore b/surfsense_backend/.gitignore index bda5961fe..efc6c90d7 100644 --- a/surfsense_backend/.gitignore +++ b/surfsense_backend/.gitignore @@ -1,12 +1,12 @@ .env .venv venv/ -/data/ +data/ .local_object_store/ __pycache__/ .flashrank_cache surf_new_backend.egg-info/ -/podcasts/ +podcasts/ video_presentation_audio/ sandbox_files/ temp_audio/ diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 8c74b637b..fba621a0c 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,7 +4,7 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add a single thread-level column to persist the Auto model pin: +Add a single thread-level column to persist the Auto (Fastest) model pin: - pinned_llm_config_id: concrete resolved global LLM config id used for this thread. NULL means "no pin; Auto will resolve on next turn". diff --git a/surfsense_backend/alembic/versions/156_unify_credits_wallet.py b/surfsense_backend/alembic/versions/156_unify_credits_wallet.py deleted file mode 100644 index 1ecf1a255..000000000 --- a/surfsense_backend/alembic/versions/156_unify_credits_wallet.py +++ /dev/null @@ -1,235 +0,0 @@ -"""unify page limits and premium credits into a single credit_micros_balance wallet - -Collapses the two separate economies (ETL ``pages_limit``/``pages_used`` and -premium ``premium_credit_micros_limit``/``premium_credit_micros_used``) into one -USD-micro wallet column ``user.credit_micros_balance`` that decreases on use and -increases on purchase / grant. ``premium_credit_micros_reserved`` is kept (renamed -to ``credit_micros_reserved``) for in-flight reservation holds. - -Backfill (per existing user row): - - balance = GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) - + (CASE WHEN pages_limit < 100000000 - THEN GREATEST(0, pages_limit - pages_used) * 1000 - ELSE 0 END) - -The ``pages_limit < 100000000`` guard skips the OSS "unlimited" default -(``PAGES_LIMIT=999999999``) so self-hosters don't get a ~$1M credit grant. -1 page == 1000 micros == $0.001 (matches the prior $1 / 1000 pages price). - -Table / type renames: - - premium_token_purchases -> credit_purchases - premiumtokenpurchasestatus (enum)-> creditpurchasestatus - user_incentive_tasks.pages_awarded -> credit_micros_awarded (backfilled * 1000) - -The "user" table is in zero_publication's column list, so this migration updates -the publication via ``apply_publication`` (canonical reconcile, per migration 155) -BEFORE dropping the old columns it referenced. - -IMPORTANT - before AND after running this migration (same as migration 140): - 1. Stop zero-cache (it holds replication locks that will deadlock DDL) - 2. Run: alembic upgrade head - 3. Delete / reset the zero-cache data volume - 4. Restart zero-cache (it will do a fresh initial sync) - -Revision ID: 156 -Revises: 155 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op -from app.zero_publication import apply_publication - -revision: str = "156" -down_revision: str | None = "155" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def _column_exists(conn, table: str, column: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.columns " - "WHERE table_name = :tbl AND column_name = :col " - "AND table_schema = current_schema()" - ), - {"tbl": table, "col": column}, - ).fetchone() - is not None - ) - - -def _table_exists(conn, table: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.tables " - "WHERE table_name = :tbl AND table_schema = current_schema()" - ), - {"tbl": table}, - ).fetchone() - is not None - ) - - -def _terminate_blocked_pids(conn, table: str) -> None: - """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" - conn.execute( - sa.text( - "SELECT pg_terminate_backend(l.pid) " - "FROM pg_locks l " - "JOIN pg_class c ON c.oid = l.relation " - "WHERE c.relname = :tbl " - " AND l.pid != pg_backend_pid()" - ), - {"tbl": table}, - ) - - -def upgrade() -> None: - conn = op.get_bind() - - # ------------------------------------------------------------------ - # 1. Add credit_micros_balance + backfill from both legacy economies. - # ------------------------------------------------------------------ - if not _column_exists(conn, "user", "credit_micros_balance"): - op.add_column( - "user", - sa.Column( - "credit_micros_balance", - sa.BigInteger(), - nullable=False, - server_default="5000000", - ), - ) - - # Backfill only when ALL legacy source columns are present (fresh DBs - # created from current models won't have them). - if all( - _column_exists(conn, "user", col) - for col in ( - "premium_credit_micros_limit", - "premium_credit_micros_used", - "pages_limit", - "pages_used", - ) - ): - conn.execute( - sa.text( - 'UPDATE "user" SET credit_micros_balance = ' - "GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) " - "+ (CASE WHEN pages_limit < 100000000 " - " THEN GREATEST(0, pages_limit - pages_used) * 1000 " - " ELSE 0 END)" - ) - ) - - # ------------------------------------------------------------------ - # 2. Rename premium_credit_micros_reserved -> credit_micros_reserved. - # ------------------------------------------------------------------ - if _column_exists( - conn, "user", "premium_credit_micros_reserved" - ) and not _column_exists(conn, "user", "credit_micros_reserved"): - op.alter_column( - "user", - "premium_credit_micros_reserved", - new_column_name="credit_micros_reserved", - ) - - # ------------------------------------------------------------------ - # 3. Reconcile the Zero publication to the new column list - # (id, credit_micros_balance) BEFORE dropping the columns it used - # to reference, otherwise Postgres rejects the column drops with - # "cannot drop column ... referenced by publication". - # ------------------------------------------------------------------ - conn.execute(sa.text("SET lock_timeout = '10s'")) - _terminate_blocked_pids(conn, "user") - apply_publication(conn) - - # ------------------------------------------------------------------ - # 4. Drop the legacy quota columns now that nothing references them. - # ------------------------------------------------------------------ - for col in ( - "premium_credit_micros_limit", - "premium_credit_micros_used", - "pages_limit", - "pages_used", - ): - if _column_exists(conn, "user", col): - op.drop_column("user", col) - - # ------------------------------------------------------------------ - # 5. Rename premium_token_purchases -> credit_purchases and its enum. - # ------------------------------------------------------------------ - op.execute( - """ - DO $$ - BEGIN - IF EXISTS ( - SELECT 1 FROM pg_type t - JOIN pg_namespace n ON n.oid = t.typnamespace - WHERE t.typname = 'premiumtokenpurchasestatus' - AND n.nspname = current_schema() - ) - AND NOT EXISTS ( - SELECT 1 FROM pg_type t - JOIN pg_namespace n ON n.oid = t.typnamespace - WHERE t.typname = 'creditpurchasestatus' - AND n.nspname = current_schema() - ) - THEN - ALTER TYPE premiumtokenpurchasestatus RENAME TO creditpurchasestatus; - END IF; - END - $$; - """ - ) - - if _table_exists(conn, "premium_token_purchases") and not _table_exists( - conn, "credit_purchases" - ): - op.rename_table("premium_token_purchases", "credit_purchases") - - # ``source`` distinguishes user checkout from auto-reload top-ups. - if _table_exists(conn, "credit_purchases") and not _column_exists( - conn, "credit_purchases", "source" - ): - op.add_column( - "credit_purchases", - sa.Column( - "source", - sa.String(length=20), - nullable=False, - server_default="checkout", - ), - ) - - # ------------------------------------------------------------------ - # 6. Rename user_incentive_tasks.pages_awarded -> credit_micros_awarded - # and convert page counts to micros (1 page == 1000 micros). - # ------------------------------------------------------------------ - if _column_exists( - conn, "user_incentive_tasks", "pages_awarded" - ) and not _column_exists(conn, "user_incentive_tasks", "credit_micros_awarded"): - op.alter_column( - "user_incentive_tasks", - "pages_awarded", - new_column_name="credit_micros_awarded", - type_=sa.BigInteger(), - ) - conn.execute( - sa.text( - "UPDATE user_incentive_tasks " - "SET credit_micros_awarded = credit_micros_awarded * 1000" - ) - ) - - -def downgrade() -> None: - """No-op. This is a one-way data-model unification; the legacy split - columns cannot be faithfully reconstructed from a single balance.""" diff --git a/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py b/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py deleted file mode 100644 index ef021b6d2..000000000 --- a/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py +++ /dev/null @@ -1,92 +0,0 @@ -"""add auto-reload (off-session Stripe top-up) columns to user - -Adds the saved-card + threshold plumbing that powers feature-flagged credit -auto-reload (``AUTO_RELOAD_ENABLED``): - - user.stripe_customer_id (text, nullable) - user.auto_reload_enabled (bool, default false) - user.auto_reload_threshold_micros (bigint, nullable) - user.auto_reload_amount_micros (bigint, nullable) - user.auto_reload_payment_method_id (text, nullable) - user.auto_reload_failed_at (timestamptz, nullable) - -None of these columns are part of the Zero publication (``USER_COLS`` is -``["id", "credit_micros_balance"]``), so this migration does NOT touch the -publication and is safe to run without the zero-cache stop/reset dance that -migration 156 required. - -The ``credit_purchases.source`` column (``checkout`` | ``auto_reload``) was -already added in migration 156, so it is not repeated here. - -Revision ID: 157 -Revises: 156 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "157" -down_revision: str | None = "156" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def _column_exists(conn, table: str, column: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.columns " - "WHERE table_name = :tbl AND column_name = :col " - "AND table_schema = current_schema()" - ), - {"tbl": table, "col": column}, - ).fetchone() - is not None - ) - - -_COLUMNS: list[tuple[str, sa.Column]] = [ - ("stripe_customer_id", sa.Column("stripe_customer_id", sa.String(), nullable=True)), - ( - "auto_reload_enabled", - sa.Column( - "auto_reload_enabled", - sa.Boolean(), - nullable=False, - server_default=sa.text("false"), - ), - ), - ( - "auto_reload_threshold_micros", - sa.Column("auto_reload_threshold_micros", sa.BigInteger(), nullable=True), - ), - ( - "auto_reload_amount_micros", - sa.Column("auto_reload_amount_micros", sa.BigInteger(), nullable=True), - ), - ( - "auto_reload_payment_method_id", - sa.Column("auto_reload_payment_method_id", sa.String(), nullable=True), - ), - ( - "auto_reload_failed_at", - sa.Column("auto_reload_failed_at", sa.TIMESTAMP(timezone=True), nullable=True), - ), -] - - -def upgrade() -> None: - conn = op.get_bind() - for name, column in _COLUMNS: - if not _column_exists(conn, "user", name): - op.add_column("user", column) - - -def downgrade() -> None: - conn = op.get_bind() - for name, _ in reversed(_COLUMNS): - if _column_exists(conn, "user", name): - op.drop_column("user", name) diff --git a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py deleted file mode 100644 index f1d231f9e..000000000 --- a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py +++ /dev/null @@ -1,215 +0,0 @@ -"""evolve podcasts: expand status lifecycle and add brief/transcript/storage columns - -Revision ID: 158 -Revises: 157 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "158" -down_revision: str | None = "157" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - -PUBLICATION_NAME = "zero_publication" -TARGET_STATUS_LABELS = ( - "pending", - "awaiting_brief", - "drafting", - "awaiting_review", - "rendering", - "ready", - "failed", - "cancelled", -) -LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed") - - -def _drop_podcasts_from_publication() -> None: - """Detach podcasts from zero_publication so status can be retyped. - - Postgres refuses ``ALTER COLUMN ... TYPE`` on a column a publication - depends on. Some databases reach this migration with podcasts already - published (an interim apply_publication ran during 156); drop it here and - let migration 159 reconcile the publication to the canonical shape. - """ - conn = op.get_bind() - published = conn.execute( - sa.text( - "SELECT 1 FROM pg_publication_tables " - "WHERE pubname = :publication " - "AND schemaname = current_schema() AND tablename = 'podcasts'" - ), - {"publication": PUBLICATION_NAME}, - ).fetchone() - if published: - op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";') - - -def _enum_labels(type_name: str) -> list[str] | None: - rows = ( - op.get_bind() - .execute( - sa.text( - "SELECT e.enumlabel " - "FROM pg_type t " - "JOIN pg_namespace n ON n.oid = t.typnamespace " - "JOIN pg_enum e ON e.enumtypid = t.oid " - "WHERE n.nspname = current_schema() AND t.typname = :type_name " - "ORDER BY e.enumsortorder" - ), - {"type_name": type_name}, - ) - .fetchall() - ) - if not rows: - return None - return [str(row[0]) for row in rows] - - -def _column_type_name(table: str, column: str) -> str | None: - row = ( - op.get_bind() - .execute( - sa.text( - "SELECT udt_name " - "FROM information_schema.columns " - "WHERE table_schema = current_schema() " - "AND table_name = :table AND column_name = :column" - ), - {"table": table, "column": column}, - ) - .fetchone() - ) - return str(row[0]) if row else None - - -def _ensure_status_enum( - *, - desired_labels: tuple[str, ...], - temporary_type: str, - create_sql: str, - alter_sql: str, - default_value: str, -) -> None: - current_labels = _enum_labels("podcast_status") - desired = list(desired_labels) - - if current_labels != desired: - if current_labels is None: - if _enum_labels(temporary_type) is None: - raise RuntimeError("podcast_status enum is missing") - elif _enum_labels(temporary_type) is None: - op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};") - else: - raise RuntimeError( - "podcast_status and its temporary replacement both exist" - ) - - if _enum_labels("podcast_status") is None: - op.execute(create_sql) - - if _enum_labels("podcast_status") != desired: - raise RuntimeError("podcast_status enum is not in the expected shape") - - op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") - if _column_type_name("podcasts", "status") != "podcast_status": - op.execute(alter_sql) - op.execute( - f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';" - ) - - if _enum_labels(temporary_type) is not None: - op.execute(f"DROP TYPE {temporary_type};") - - -def _upgrade_status_enum() -> None: - _ensure_status_enum( - desired_labels=TARGET_STATUS_LABELS, - temporary_type="podcast_status_old", - create_sql=""" - CREATE TYPE podcast_status AS ENUM ( - 'pending', 'awaiting_brief', 'drafting', 'awaiting_review', - 'rendering', 'ready', 'failed', 'cancelled' - ); - """, - alter_sql=""" - ALTER TABLE podcasts - ALTER COLUMN status TYPE podcast_status - USING ( - CASE status::text - WHEN 'generating' THEN 'rendering' - ELSE status::text - END - )::podcast_status; - """, - default_value="pending", - ) - - -def _downgrade_status_enum() -> None: - _ensure_status_enum( - desired_labels=LEGACY_STATUS_LABELS, - temporary_type="podcast_status_new", - create_sql=( - "CREATE TYPE podcast_status AS ENUM " - "('pending', 'generating', 'ready', 'failed');" - ), - alter_sql=""" - ALTER TABLE podcasts - ALTER COLUMN status TYPE podcast_status - USING ( - CASE status::text - WHEN 'awaiting_brief' THEN 'pending' - WHEN 'drafting' THEN 'generating' - WHEN 'awaiting_review' THEN 'generating' - WHEN 'rendering' THEN 'generating' - WHEN 'cancelled' THEN 'failed' - ELSE status::text - END - )::podcast_status; - """, - default_value="ready", - ) - - -def upgrade() -> None: - _drop_podcasts_from_publication() - - # Retype the status enum by swapping in a fresh type and casting existing - # rows. The legacy transient value 'generating' maps onto 'rendering'. - _upgrade_status_enum() - - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;") - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;") - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec_version " - "INTEGER NOT NULL DEFAULT 1;" - ) - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_backend VARCHAR(32);" - ) - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_key TEXT;") - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS duration_seconds INTEGER;" - ) - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS error TEXT;") - - -def downgrade() -> None: - _drop_podcasts_from_publication() - - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_backend;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec_version;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;") - - # Collapse the expanded lifecycle back onto the original four values. - _downgrade_status_enum() diff --git a/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py b/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py deleted file mode 100644 index 1667ca96b..000000000 --- a/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py +++ /dev/null @@ -1,26 +0,0 @@ -"""publish podcasts to zero_publication - -Reconciles ``zero_publication`` after migration 158 added the lifecycle columns, -so the frontend observes podcast status and the reviewable brief by push. - -Revision ID: 159 -Revises: 158 -""" - -from collections.abc import Sequence - -from alembic import op -from app.zero_publication import apply_publication - -revision: str = "159" -down_revision: str | None = "158" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - apply_publication(op.get_bind()) - - -def downgrade() -> None: - """No-op. Historical publication shapes are immutable.""" diff --git a/surfsense_backend/alembic/versions/160_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py deleted file mode 100644 index fea45aca7..000000000 --- a/surfsense_backend/alembic/versions/160_add_model_connections.py +++ /dev/null @@ -1,299 +0,0 @@ -"""add model connections - -Revision ID: 160 -Revises: 159 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -from alembic import op - -revision: str = "160" -down_revision: str | None = "159" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -connection_scope = postgresql.ENUM( - "GLOBAL", - "SEARCH_SPACE", - "USER", - name="connectionscope", - create_type=False, -) -model_source = postgresql.ENUM( - "DISCOVERED", - "MANUAL", - name="modelsource", - create_type=False, -) - - -def _table_exists(table_name: str) -> bool: - return table_name in sa.inspect(op.get_bind()).get_table_names() - - -def _column_exists(table_name: str, column_name: str) -> bool: - if not _table_exists(table_name): - return False - return column_name in { - column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) - } - - -def _index_exists(table_name: str, index_name: str) -> bool: - if not _table_exists(table_name): - return False - return index_name in { - index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name) - } - - -def _create_index_if_missing( - index_name: str, - table_name: str, - columns: list[str], -) -> None: - if not _index_exists(table_name, index_name): - op.create_index(index_name, table_name, columns, unique=False) - - -def _add_searchspace_column_if_missing( - column_name: str, - *, - server_default: object | None = None, -) -> None: - if not _column_exists("searchspaces", column_name): - op.add_column( - "searchspaces", - sa.Column( - column_name, - sa.Integer(), - nullable=True, - server_default=server_default, - ), - ) - - -def _drop_column_if_exists(table_name: str, column_name: str) -> None: - if _column_exists(table_name, column_name): - op.drop_column(table_name, column_name) - - -def _drop_index_if_exists(table_name: str, index_name: str) -> None: - if _index_exists(table_name, index_name): - op.drop_index(index_name, table_name=table_name) - - -def upgrade() -> None: - bind = op.get_bind() - connection_scope.create(bind, checkfirst=True) - model_source.create(bind, checkfirst=True) - - if _table_exists("connections"): - if _column_exists("connections", "litellm_provider") and not _column_exists( - "connections", "provider" - ): - op.alter_column( - "connections", - "litellm_provider", - new_column_name="provider", - existing_type=sa.String(length=100), - existing_nullable=True, - ) - op.alter_column( - "connections", - "provider", - existing_type=sa.String(length=100), - nullable=False, - ) - elif _column_exists("connections", "native_provider") and not _column_exists( - "connections", "provider" - ): - op.alter_column( - "connections", - "native_provider", - new_column_name="provider", - existing_type=sa.String(length=100), - existing_nullable=True, - ) - op.alter_column( - "connections", - "provider", - existing_type=sa.String(length=100), - nullable=False, - ) - elif not _column_exists("connections", "provider"): - op.add_column( - "connections", - sa.Column("provider", sa.String(length=100), nullable=False), - ) - _drop_index_if_exists("connections", "ix_connections_protocol") - _drop_column_if_exists("connections", "protocol") - else: - op.create_table( - "connections", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("provider", sa.String(length=100), nullable=False), - sa.Column("base_url", sa.String(length=500), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column( - "extra", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("scope", connection_scope, nullable=False), - sa.Column( - "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False - ), - sa.Column("search_space_id", sa.Integer(), nullable=True), - sa.Column("user_id", sa.UUID(), nullable=True), - sa.CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - if _index_exists( - "connections", "ix_connections_native_provider" - ) and not _index_exists("connections", "ix_connections_provider"): - op.execute( - "ALTER INDEX ix_connections_native_provider " - "RENAME TO ix_connections_provider" - ) - if _index_exists( - "connections", "ix_connections_litellm_provider" - ) and not _index_exists("connections", "ix_connections_provider"): - op.execute( - "ALTER INDEX ix_connections_litellm_provider " - "RENAME TO ix_connections_provider" - ) - _create_index_if_missing("ix_connections_provider", "connections", ["provider"]) - _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) - - if not _table_exists("models"): - op.create_table( - "models", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("connection_id", sa.Integer(), nullable=False), - sa.Column("model_id", sa.String(length=255), nullable=False), - sa.Column("display_name", sa.String(length=255), nullable=True), - sa.Column( - "source", - model_source, - server_default="DISCOVERED", - nullable=False, - ), - sa.Column("supports_chat", sa.Boolean(), nullable=True), - sa.Column("max_input_tokens", sa.Integer(), nullable=True), - sa.Column("supports_image_input", sa.Boolean(), nullable=True), - sa.Column("supports_tools", sa.Boolean(), nullable=True), - sa.Column("supports_image_generation", sa.Boolean(), nullable=True), - sa.Column( - "capabilities_override", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False - ), - sa.Column("billing_tier", sa.String(length=50), nullable=True), - sa.Column( - "catalog", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.ForeignKeyConstraint( - ["connection_id"], ["connections.id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - ) - else: - if not _column_exists("models", "supports_chat"): - op.add_column( - "models", sa.Column("supports_chat", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "max_input_tokens"): - op.add_column( - "models", sa.Column("max_input_tokens", sa.Integer(), nullable=True) - ) - if not _column_exists("models", "supports_image_input"): - op.add_column( - "models", sa.Column("supports_image_input", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "supports_tools"): - op.add_column( - "models", sa.Column("supports_tools", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "supports_image_generation"): - op.add_column( - "models", - sa.Column("supports_image_generation", sa.Boolean(), nullable=True), - ) - _drop_column_if_exists("models", "capabilities") - _drop_column_if_exists("models", "capabilities_declared") - _drop_column_if_exists("models", "capabilities_verified") - _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) - _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) - _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - - _add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0")) - _add_searchspace_column_if_missing( - "image_gen_model_id", server_default=sa.text("0") - ) - _add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0")) - for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"): - op.alter_column( - "searchspaces", - column_name, - existing_type=sa.Integer(), - existing_nullable=True, - server_default=sa.text("0"), - ) - op.execute( - """ - UPDATE searchspaces - SET - chat_model_id = COALESCE(chat_model_id, 0), - image_gen_model_id = COALESCE(image_gen_model_id, 0), - vision_model_id = COALESCE(vision_model_id, 0) - """ - ) - - op.execute("DROP TYPE IF EXISTS connectionprotocol") - - -def downgrade() -> None: - op.drop_column("searchspaces", "vision_model_id") - op.drop_column("searchspaces", "image_gen_model_id") - op.drop_column("searchspaces", "chat_model_id") - - op.drop_index(op.f("ix_models_billing_tier"), table_name="models") - op.drop_index("ix_models_model_id", table_name="models") - op.drop_index(op.f("ix_models_connection_id"), table_name="models") - op.drop_table("models") - - op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_provider"), table_name="connections") - op.drop_table("connections") - - bind = op.get_bind() - model_source.drop(bind, checkfirst=True) - connection_scope.drop(bind, checkfirst=True) diff --git a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py deleted file mode 100644 index 2108d763c..000000000 --- a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py +++ /dev/null @@ -1,270 +0,0 @@ -"""remove legacy model config tables - -Revision ID: 161 -Revises: 160 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql -from sqlalchemy.types import TypeEngine - -from alembic import op - -revision: str = "161" -down_revision: str | None = "160" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -litellm_provider = postgresql.ENUM( - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "BEDROCK", - "VERTEX_AI", - "GROQ", - "COHERE", - "MISTRAL", - "DEEPSEEK", - "XAI", - "OPENROUTER", - "TOGETHER_AI", - "FIREWORKS_AI", - "REPLICATE", - "PERPLEXITY", - "OLLAMA", - "ALIBABA_QWEN", - "MOONSHOT", - "ZHIPU", - "ANYSCALE", - "DEEPINFRA", - "CEREBRAS", - "SAMBANOVA", - "AI21", - "CLOUDFLARE", - "DATABRICKS", - "COMETAPI", - "HUGGINGFACE", - "GITHUB_MODELS", - "MINIMAX", - "CUSTOM", - name="litellmprovider", - create_type=False, -) -image_gen_provider = postgresql.ENUM( - "OPENAI", - "AZURE_OPENAI", - "GOOGLE", - "VERTEX_AI", - "BEDROCK", - "RECRAFT", - "OPENROUTER", - "XINFERENCE", - "NSCALE", - name="imagegenprovider", - create_type=False, -) -vision_provider = postgresql.ENUM( - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "VERTEX_AI", - "BEDROCK", - "XAI", - "OPENROUTER", - "OLLAMA", - "GROQ", - "TOGETHER_AI", - "FIREWORKS_AI", - "DEEPSEEK", - "MISTRAL", - "CUSTOM", - name="visionprovider", - create_type=False, -) - - -def _table_exists(table_name: str) -> bool: - return table_name in sa.inspect(op.get_bind()).get_table_names() - - -def _column_exists(table_name: str, column_name: str) -> bool: - if not _table_exists(table_name): - return False - return column_name in { - column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) - } - - -def _drop_column_if_exists(table_name: str, column_name: str) -> None: - if _column_exists(table_name, column_name): - op.drop_column(table_name, column_name) - - -def _rename_column_if_exists( - table_name: str, - old_column_name: str, - new_column_name: str, - *, - existing_type: TypeEngine, - existing_nullable: bool = True, -) -> None: - if _column_exists(table_name, old_column_name) and not _column_exists( - table_name, new_column_name - ): - op.alter_column( - table_name, - old_column_name, - new_column_name=new_column_name, - existing_type=existing_type, - existing_nullable=existing_nullable, - ) - - -def upgrade() -> None: - for table_name in ( - "new_llm_configs", - "vision_llm_configs", - "image_generation_configs", - ): - if _table_exists(table_name): - op.drop_table(table_name) - - _drop_column_if_exists("searchspaces", "agent_llm_id") - _drop_column_if_exists("searchspaces", "image_generation_config_id") - _drop_column_if_exists("searchspaces", "vision_llm_config_id") - - _rename_column_if_exists( - "image_generations", - "image_generation_config_id", - "image_gen_model_id", - existing_type=sa.Integer(), - ) - - op.execute("DROP TYPE IF EXISTS litellmprovider") - op.execute("DROP TYPE IF EXISTS imagegenprovider") - op.execute("DROP TYPE IF EXISTS visionprovider") - - -def downgrade() -> None: - bind = op.get_bind() - litellm_provider.create(bind, checkfirst=True) - image_gen_provider.create(bind, checkfirst=True) - vision_provider.create(bind, checkfirst=True) - - _rename_column_if_exists( - "image_generations", - "image_gen_model_id", - "image_generation_config_id", - existing_type=sa.Integer(), - ) - - if _table_exists("searchspaces"): - if not _column_exists("searchspaces", "agent_llm_id"): - op.add_column( - "searchspaces", - sa.Column("agent_llm_id", sa.Integer(), nullable=True), - ) - if not _column_exists("searchspaces", "image_generation_config_id"): - op.add_column( - "searchspaces", - sa.Column("image_generation_config_id", sa.Integer(), nullable=True), - ) - if not _column_exists("searchspaces", "vision_llm_config_id"): - op.add_column( - "searchspaces", - sa.Column("vision_llm_config_id", sa.Integer(), nullable=True), - ) - - if not _table_exists("image_generation_configs"): - op.create_table( - "image_generation_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", image_gen_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("api_version", sa.String(length=50), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_image_generation_configs_name"), - "image_generation_configs", - ["name"], - unique=False, - ) - - if not _table_exists("vision_llm_configs"): - op.create_table( - "vision_llm_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", vision_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("api_version", sa.String(length=50), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_vision_llm_configs_name"), - "vision_llm_configs", - ["name"], - unique=False, - ) - - if not _table_exists("new_llm_configs"): - op.create_table( - "new_llm_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", litellm_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("system_instructions", sa.Text(), nullable=False), - sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False), - sa.Column("citations_enabled", sa.Boolean(), nullable=False), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_new_llm_configs_name"), - "new_llm_configs", - ["name"], - unique=False, - ) diff --git a/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py b/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py deleted file mode 100644 index 87e1c5813..000000000 --- a/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py +++ /dev/null @@ -1,53 +0,0 @@ -"""add etl_cache_parses table for content-addressed parse reuse - -Revision ID: 162 -Revises: 161 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "162" -down_revision: str | None = "161" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.execute( - """ - CREATE TABLE IF NOT EXISTS etl_cache_parses ( - id SERIAL PRIMARY KEY, - source_sha256 VARCHAR(64) NOT NULL, - etl_service VARCHAR(32) NOT NULL, - mode VARCHAR(16) NOT NULL, - parser_version INTEGER NOT NULL, - storage_backend VARCHAR(32) NOT NULL, - storage_key TEXT NOT NULL, - size_bytes BIGINT NOT NULL, - content_type VARCHAR(32) NOT NULL, - actual_pages INTEGER NOT NULL DEFAULT 0, - times_reused BIGINT NOT NULL DEFAULT 0, - last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - CONSTRAINT uq_etl_cache_parses_key - UNIQUE (source_sha256, etl_service, mode, parser_version) - ); - """ - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_last_used_at " - "ON etl_cache_parses(last_used_at);" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_created_at " - "ON etl_cache_parses(created_at);" - ) - - -def downgrade() -> None: - op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_created_at;") - op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_last_used_at;") - op.execute("DROP TABLE IF EXISTS etl_cache_parses;") diff --git a/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py b/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py deleted file mode 100644 index f15616332..000000000 --- a/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py +++ /dev/null @@ -1,53 +0,0 @@ -"""add embedding_cache_sets table for content-addressed embedding reuse - -Revision ID: 163 -Revises: 162 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "163" -down_revision: str | None = "162" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.execute( - """ - CREATE TABLE IF NOT EXISTS embedding_cache_sets ( - id SERIAL PRIMARY KEY, - markdown_sha256 VARCHAR(64) NOT NULL, - embedding_model VARCHAR(255) NOT NULL, - embedding_dim INTEGER NOT NULL, - chunker_kind VARCHAR(8) NOT NULL, - chunker_version INTEGER NOT NULL, - storage_backend VARCHAR(32) NOT NULL, - storage_key TEXT NOT NULL, - size_bytes BIGINT NOT NULL, - chunk_count INTEGER NOT NULL DEFAULT 0, - times_reused BIGINT NOT NULL DEFAULT 0, - last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - CONSTRAINT uq_embedding_cache_sets_key - UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version) - ); - """ - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_last_used_at " - "ON embedding_cache_sets(last_used_at);" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_created_at " - "ON embedding_cache_sets(created_at);" - ) - - -def downgrade() -> None: - op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_created_at;") - op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_last_used_at;") - op.execute("DROP TABLE IF EXISTS embedding_cache_sets;") diff --git a/surfsense_backend/alembic/versions/164_remove_inactive_users.py b/surfsense_backend/alembic/versions/164_remove_inactive_users.py deleted file mode 100644 index 3ce23e204..000000000 --- a/surfsense_backend/alembic/versions/164_remove_inactive_users.py +++ /dev/null @@ -1,219 +0,0 @@ -"""remove users that never logged back in (last_login IS NULL) - -Migration 103 added ``user.last_login``. Any user whose ``last_login`` is still -NULL has never authenticated since that column existed, i.e. they never logged -back in. This migration purges those users together with everything that hangs -off them: the search spaces they own, and (via ON DELETE CASCADE) -``searchspaces -> documents -> chunks`` plus all other user/space-scoped rows. - -This runs BEFORE the chunks.position backfill (revision 165) on purpose: it -removes a large amount of dead chunk data first, so the expensive backfill has -far fewer rows to rewrite. - -Work is done in committed batches (not one giant cascading DELETE) so that on a -large table it streams progress to the alembic console, keeps each transaction -small, bounds WAL/bloat growth, and is resumable if interrupted. - -Revision ID: 164 -Revises: 163 -""" - -import logging -import time -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "164" -down_revision: str | None = "163" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - -# Documents removed per committed batch. Each document delete cascades to its -# chunks (via ix_chunks_document_id), so keep this modest to bound batch size. -DOC_BATCH = 1_000 -# Users removed per committed batch. Each cascades to owned search spaces and -# the remaining space-/user-scoped rows. -USER_BATCH = 500 -# Minimum seconds between progress log lines (keeps the console readable). -LOG_EVERY_SECONDS = 5.0 - -USER_SCRATCH = "_inactive_user_ids" -DOC_SCRATCH = "_inactive_doc_ids" - -logger = logging.getLogger("alembic.runtime.migration") - - -def _fmt_duration(seconds: float) -> str: - seconds = int(seconds) - h, rem = divmod(seconds, 3600) - m, s = divmod(rem, 60) - if h: - return f"{h}h{m:02d}m{s:02d}s" - if m: - return f"{m}m{s:02d}s" - return f"{s}s" - - -def upgrade() -> None: - bind = op.get_bind() - - # Run the heavy work outside the migration's single transaction so each - # batch can commit on its own. - with op.get_context().autocommit_block(): - # Materialize the target user ids once. Rebuilt from scratch on every - # run, so a re-run after an interruption simply picks up whoever still - # has NULL last_login -> the migration is idempotent and resumable. - op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - op.execute( - f"CREATE UNLOGGED TABLE {USER_SCRATCH} AS " - 'SELECT id FROM "user" WHERE last_login IS NULL;' - ) - op.execute(f"ALTER TABLE {USER_SCRATCH} ADD PRIMARY KEY (id);") - - total_users = ( - bind.execute(sa.text(f"SELECT count(*) FROM {USER_SCRATCH}")).scalar() or 0 - ) - if total_users == 0: - logger.info("no users with NULL last_login; nothing to remove") - op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - return - - logger.info( - "found %s users with NULL last_login (never logged back in); " - "removing them and all data in search spaces they own", - f"{total_users:,}", - ) - - # Documents living in search spaces owned by those users. Deleting these - # explicitly (in batches) is what bounds the otherwise-unbounded - # chunks cascade. - op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};") - op.execute( - f""" - CREATE UNLOGGED TABLE {DOC_SCRATCH} AS - SELECT d.id - FROM documents d - JOIN searchspaces s ON s.id = d.search_space_id - WHERE s.user_id IN (SELECT id FROM {USER_SCRATCH}); - """ - ) - op.execute(f"ALTER TABLE {DOC_SCRATCH} ADD PRIMARY KEY (id);") - total_docs = ( - bind.execute(sa.text(f"SELECT count(*) FROM {DOC_SCRATCH}")).scalar() or 0 - ) - - # Phase 1: delete documents (cascades chunks, document_versions, - # document_files) in committed batches. - logger.info( - "phase 1/2: deleting %s documents (cascades their chunks) " - "in batches of %s...", - f"{total_docs:,}", - f"{DOC_BATCH:,}", - ) - _batched_delete( - bind, - scratch=DOC_SCRATCH, - target_table="documents", - target_col="id", - batch_size=DOC_BATCH, - total=total_docs, - label="documents", - ) - op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};") - - # Phase 2: delete the users themselves. This cascades the now-empty - # search spaces plus all remaining user-/space-scoped rows. - logger.info( - "phase 2/2: deleting %s users (cascades search spaces and " - "remaining data) in batches of %s...", - f"{total_users:,}", - f"{USER_BATCH:,}", - ) - _batched_delete( - bind, - scratch=USER_SCRATCH, - target_table='"user"', - target_col="id", - batch_size=USER_BATCH, - total=total_users, - label="users", - ) - op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - - logger.info("migration 164 finished") - - -def _batched_delete( - bind: sa.engine.Connection, - *, - scratch: str, - target_table: str, - target_col: str, - batch_size: int, - total: int, - label: str, -) -> None: - """Pop ids from ``scratch`` and delete the matching rows, one committed - batch at a time, logging progress. Atomic per batch: the row delete and the - scratch pop happen in a single statement, so an interrupted run leaves the - scratch table in sync with what has actually been deleted.""" - started = time.monotonic() - last_log = 0.0 - done = 0 - - stmt = sa.text( - f""" - WITH batch AS ( - SELECT id FROM {scratch} LIMIT :n - ), deleted AS ( - DELETE FROM {target_table} - WHERE {target_col} IN (SELECT id FROM batch) - ), popped AS ( - DELETE FROM {scratch} - WHERE id IN (SELECT id FROM batch) - RETURNING id - ) - SELECT count(*) FROM popped - """ - ) - - while True: - popped = bind.execute(stmt, {"n": batch_size}).scalar() or 0 - if popped == 0: - break - done += popped - - now = time.monotonic() - if now - last_log >= LOG_EVERY_SECONDS or done >= total: - elapsed = now - started - pct = (100.0 * done / total) if total else 100.0 - eta = (elapsed / pct * (100.0 - pct)) if pct > 0 else 0.0 - logger.info( - "%s deleted: %.1f%% (%s/%s) elapsed %s eta %s", - label, - pct, - f"{done:,}", - f"{total:,}", - _fmt_duration(elapsed), - _fmt_duration(eta), - ) - last_log = now - - logger.info( - "deleted %s %s in %s", - f"{done:,}", - label, - _fmt_duration(time.monotonic() - started), - ) - - -def downgrade() -> None: - # Irreversible: deleted users and their cascaded data cannot be restored. - # No-op so the downgrade chain can still pass through this revision. - logger.warning( - "migration 164 (remove_inactive_users) is irreversible; " - "downgrade is a no-op (deleted users/data are not restored)" - ) diff --git a/surfsense_backend/alembic/versions/165_add_chunk_position.py b/surfsense_backend/alembic/versions/165_add_chunk_position.py deleted file mode 100644 index b214f3d89..000000000 --- a/surfsense_backend/alembic/versions/165_add_chunk_position.py +++ /dev/null @@ -1,49 +0,0 @@ -"""add chunks.position for explicit document order - -Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no -longer reflect document order. The ``position`` column makes that order -explicit and is written by the indexing pipeline for every new or re-indexed -document. - -This migration intentionally does NOT backfill historical rows. On a large, -heavily-indexed table (notably a multi-hundred-GB HNSW embedding index) a bulk -UPDATE of every chunk becomes a non-HOT update that rewrites every secondary -index per row -- turning a one-off migration into a multi-day operation. -Instead, existing rows keep ``position = 0`` and therefore order by the -``Chunk.id`` tiebreaker (identical to the pre-feature behavior); new and -re-indexed documents get correct positions from application code, and any -document needing exact order can simply be re-indexed on demand. - -Revision ID: 165 -Revises: 164 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "165" -down_revision: str | None = "164" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - -# Leftover UNLOGGED scratch table from earlier backfill attempts; dropped here -# so re-running this migration converges the schema regardless of past state. -SCRATCH_TABLE = "_chunk_position_backfill" - - -def upgrade() -> None: - # Adding a NOT NULL column with a constant default is metadata-only on - # PostgreSQL 11+, so this is fast even on very large tables. IF NOT EXISTS - # makes it a no-op where the column already exists. - op.execute( - "ALTER TABLE chunks ADD COLUMN IF NOT EXISTS position INTEGER NOT NULL DEFAULT 0;" - ) - - # Clean up the scratch table left behind by the abandoned backfill approach. - op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};") - - -def downgrade() -> None: - op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};") - op.execute("ALTER TABLE chunks DROP COLUMN IF EXISTS position;") diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py index a6c83a7d4..ef86eaddd 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py @@ -241,15 +241,8 @@ async def _create_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk( - document_id=doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunks, chunk_embeddings, strict=True) - ) + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) ] ) return doc @@ -296,15 +289,8 @@ async def _update_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk( - document_id=document.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunks, chunk_embeddings, strict=True) - ) + Chunk(document_id=document.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) ] ) return document @@ -489,9 +475,7 @@ async def _load_chunks_for_snapshot( session: AsyncSession, *, doc_id: int ) -> list[dict[str, str]]: rows = await session.execute( - select(Chunk.content) - .where(Chunk.document_id == doc_id) - .order_by(Chunk.position, Chunk.id) + select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) ) return [{"content": row.content} for row in rows.all() if row.content is not None] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py index 2d3599de0..6ac22e575 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,7 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, - image_gen_model_id_override: int | None = None, + image_generation_config_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -121,7 +121,7 @@ async def build_agent_with_cache( # Bound into the generate_image subagent tool at construction time, so it # must key the compiled-agent cache to avoid leaking one automation's # image model into another with the same config_id/search_space. - image_gen_model_id_override, + image_generation_config_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index 10a734192..adb1bc1ed 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, - image_gen_model_id: int | None = None, + image_generation_config_id: int | None = None, ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. - ``image_gen_model_id`` overrides the search space's image model for + ``image_generation_config_id`` overrides the search space's image model for this invocation (used by automations to run on their captured model). When ``None``, the ``generate_image`` tool resolves the live search-space pref. """ @@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent( "llm": llm, # Per-invocation image model override (automations run on their captured # model). Reaches the generate_image subagent tool via subagent_dependencies. - "image_gen_model_id_override": image_gen_model_id, + "image_generation_config_id_override": image_generation_config_id, } _t0 = time.perf_counter() @@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, - image_gen_model_id_override=image_gen_model_id, + image_generation_config_id_override=image_generation_config_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md index aa6217041..28cf0ac63 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md @@ -126,25 +126,23 @@ user: "Create issues in Linear for each of these five bugs: " user: "Make a 30-second podcast of this conversation." -→ Podcast deliverable. The `deliverables` subagent sets the podcast up and - returns **immediately** — generation does not happen during the call. A - live card in the chat takes over from there: the user reviews the brief - (language, voices, length) on the card, and the episode drafts and - renders automatically after they approve. +→ Celery-backed deliverable. The `deliverables` subagent dispatches the + Celery job and then **waits for it to finish** before returning. The + call may take 10-60 seconds (or longer for video presentations) — + that is intentional, not a hang. You always get back one of two + Receipt shapes: task(deliverables, "Generate a podcast titled '' from the - following content. Aim for a 30-second style brief. Return the - podcast id and title.\n\n<source content>") + following content. Use a 30-second style brief. Return the podcast + id and title.\n\n<source content>") Outcomes: - - **`status="success"`**: the podcast is set up. Do NOT describe its - current status or promise it is ready — the card tracks progress - live and will outlive whatever you say. Just point the user at the - card in the chat. + - **`status="success"`**: the audio is saved. Tell the user the + podcast is **ready** and quote the `external_id` / `preview` so + they can find it in the podcast panel. - **`status="failed"`**: surface the Receipt's `error` field verbatim. Do NOT silently re-dispatch — the backend already tried and reported a real error. - Video presentations differ: that Celery-backed call **waits for the - render to finish** before returning (possibly minutes — intentional, - not a hang) and ends with a terminal status. If a + Same two-way pattern applies to video presentations (which take + longer to render, but still return a terminal status). If a `task(deliverables, ...)` invocation itself times out at the subagent layer (separate from the Receipt), that's an operator-side problem with the subagent invoke timeout, not a deliverable failure — pass diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py index e13196537..7b8aaf2b0 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py @@ -508,7 +508,7 @@ class KBPostgresBackend(BackendProtocol): chunk_rows = await session.execute( select(Chunk.id, Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunks = [ {"chunk_id": row.id, "content": row.content} for row in chunk_rows.all() @@ -725,7 +725,7 @@ class KBPostgresBackend(BackendProtocol): .join(Document, Document.id == Chunk.document_id) .where(Document.search_space_id == self.search_space_id) .where(Chunk.content.ilike(f"%{pattern}%")) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunk_rows = await session.execute(sub) per_doc: dict[int, int] = {} diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py index 9ef601791..681e80b0e 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py @@ -394,10 +394,7 @@ async def browse_recent_documents( Chunk.document_id, Chunk.content, func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -407,7 +404,7 @@ async def browse_recent_documents( chunk_query = ( select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content) .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC) - .order_by(numbered.c.document_id, numbered.c.rn) + .order_by(numbered.c.document_id, numbered.c.chunk_id) ) chunk_result = await session.execute(chunk_query) fetched_chunks = chunk_result.all() @@ -534,7 +531,7 @@ async def fetch_mentioned_documents( chunk_result = await session.execute( select(Chunk.id, Chunk.content, Chunk.document_id) .where(Chunk.document_id.in_(list(docs.keys()))) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs} for row in chunk_result.all(): diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 736c508ff..7bb4a7c24 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -10,53 +10,70 @@ from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, - Model, + ImageGenerationConfig, SearchSpace, shielded_async_session, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) - -def _get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) +# Provider mapping (same as routes) +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} -def _get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image gen config by negative ID.""" + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, - image_gen_model_id_override: int | None = None, + image_generation_config_id_override: int | None = None, ): """Create ``generate_image`` with bound search space; DB work uses a per-call session. - ``image_gen_model_id_override``: when set (automations running on a - captured model), use this model id instead of reading the search space's - live ``image_gen_model_id``. + ``image_generation_config_id_override``: when set (automations running on a + captured model), use this config id instead of reading the search space's + live ``image_generation_config_id``. """ del db_session # tool uses a fresh per-call session instead @@ -101,23 +118,26 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return _failed( - {"error": "Search space not found"}, - error="Search space not found", - ) - - if image_gen_model_id_override is not None: + if image_generation_config_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. - config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID - else: config_id = ( - search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID + image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + ) + else: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) + + config_id = ( + search_space.image_generation_config_id + or IMAGE_GEN_AUTO_MODE_ID ) # size/quality/style are intentionally omitted: valid values @@ -127,86 +147,73 @@ def create_generate_image_tool( gen_kwargs["n"] = n if is_image_gen_auto_mode(config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=search_space.user_id, - capability="image_gen", - ) - if not candidates: + if not ImageGenRouterService.is_initialized(): err = ( - "No image generation models available. " + "No image generation models configured. " "Please add an image model in Settings > Image Models." ) return _failed({"error": err}, error=err) - config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs ) - - provider_base_url: str | None = None - - if config_id < 0: - global_model = _get_global_model(config_id) - if not global_model or not has_capability( - global_model, "image_gen" - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - global_connection = _get_global_connection( - global_model["connection_id"] - ) - if not global_connection: - err = f"Image generation connection for model {config_id} not found" + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string, resolved_kwargs = to_litellm( - global_connection, - global_model["model_id"], + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) - gen_kwargs.update(resolved_kwargs) - provider_base_url = resolved_kwargs.get("api_base") + model_string = f"{provider_prefix}/{cfg['model_name']}" + gen_kwargs["api_key"] = cfg.get("api_key") + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = Model + Connection + # Positive ID = user-created ImageGenerationConfig cfg_result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .filter(Model.id == config_id, Model.enabled.is_(True)) + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) ) - db_model = cfg_result.scalars().first() - if ( - not db_model - or not db_model.connection - or not db_model.connection.enabled - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - conn = db_model.connection - if ( - conn.search_space_id is not None - and conn.search_space_id != search_space_id - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - if ( - conn.user_id is not None - and conn.user_id != search_space.user_id - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - if not has_capability(db_model, "image_gen"): - err = f"Model {config_id} is not image-generation capable" + db_cfg = cfg_result.scalars().first() + if not db_cfg: + err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string, resolved_kwargs = to_litellm( - db_model.connection, - db_model.model_id, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) - gen_kwargs.update(resolved_kwargs) - provider_base_url = resolved_kwargs.get("api_base") + model_string = f"{provider_prefix}/{db_cfg.model_name}" + gen_kwargs["api_key"] = db_cfg.api_key + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs @@ -223,7 +230,7 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - image_gen_model_id=config_id, + image_generation_config_id=config_id, response_data=response_dict, search_space_id=search_space_id, access_token=access_token, @@ -245,19 +252,8 @@ def create_generate_image_tool( # b64_json (e.g. gpt-image-1) is served via our backend endpoint so # megabytes of base64 don't bloat the LLM context. - # Some OpenAI-compatible backends (e.g. Xinference) return a relative - # URL like /files/image.png. Browsers can't resolve these, so we - # prepend the provider's base origin when the URL starts with "/". if first_image.get("url"): - raw_url: str = first_image["url"] - if raw_url.startswith("/") and provider_base_url: - from urllib.parse import urlparse - - parsed = urlparse(provider_base_url) - origin = f"{parsed.scheme}://{parsed.netloc}" - image_url = f"{origin}{raw_url}" - else: - image_url = raw_url + image_url = first_image["url"] elif first_image.get("b64_json"): backend_url = config.BACKEND_URL or "http://localhost:8000" image_url = ( diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index 8de95f2df..b968c1701 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,6 +51,8 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], - image_gen_model_id_override=d.get("image_gen_model_id_override"), + image_generation_config_id_override=d.get( + "image_generation_config_id_override" + ), ), ] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py index d89124990..e99e0291a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py @@ -122,7 +122,7 @@ async def _browse_recent_documents( chunk_query = ( select(Chunk) .where(Chunk.document_id.in_(doc_ids)) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunk_result = await session.execute(chunk_query) raw_chunks = chunk_result.scalars().all() diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py index d8d28ceb1..bfa3cc100 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py @@ -1,10 +1,11 @@ """Factory for a podcast-generation tool. -Creates the podcast and proposes its brief (language, voices, length) inline, -then returns immediately with the row awaiting review. Everything after — -brief approval, drafting, rendering — happens on the live podcast card, so -this tool never blocks on generation and the chat text must not describe a -status that the card will outgrow. +Dispatches the heavy generation to Celery and then polls the podcast row +until it reaches a terminal status (READY/FAILED). The tool always +returns a real terminal ``Receipt`` — never a pending one. The wait is +bounded by the existing per-invocation safety net +(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode, +HTTP / process lifetime in single-agent mode). """ import logging @@ -17,12 +18,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import ( + wait_for_deliverable, +) from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import ( resolve_root_thread_id, ) -from app.db import PodcastStatus, shielded_async_session -from app.podcasts.generation.brief import propose_brief -from app.podcasts.service import PodcastService +from app.db import Podcast, PodcastStatus, shielded_async_session logger = logging.getLogger(__name__) @@ -43,7 +45,7 @@ def create_generate_podcast_tool( user_prompt: str | None = None, ) -> Command: """ - Prepare a podcast from the provided content for the user to review. + Generate a podcast from the provided content. Use this tool when the user asks to create, generate, or make a podcast. Common triggers include phrases like: @@ -53,59 +55,100 @@ def create_generate_podcast_tool( - "Make a podcast about..." - "Turn this into a podcast" - This sets up the podcast and proposes its brief (language, voices, - length). The user reviews the brief on the live podcast card in the - chat; after approval the episode drafts and renders automatically. - Generation does not start here, and the card tracks all progress — do - not describe the podcast's current status in your reply. - Args: source_content: The text content to convert into a podcast. podcast_title: Title for the podcast (default: "SurfSense Podcast") - user_prompt: Optional steer for what the episode should focus on. + user_prompt: Optional instructions for podcast style, tone, or format. Returns: A dictionary containing: - - status: the podcast lifecycle status (awaiting_brief on success) - - podcast_id: the podcast ID to review in the panel - - title: the podcast title - - message: what the user should do next (or "error" when failed) + - status: PodcastStatus value (pending, generating, or failed) + - podcast_id: The podcast ID for polling (when status is pending or generating) + - title: The podcast title + - message: Status message (or "error" field if status is failed) """ try: # One DB session per tool call so parallel invocations never share an AsyncSession. async with shielded_async_session() as session: - service = PodcastService(session) - podcast = await service.create( + podcast = Podcast( title=podcast_title, + status=PodcastStatus.PENDING, search_space_id=search_space_id, thread_id=resolve_root_thread_id(runtime, thread_id), ) - podcast.source_content = source_content - spec = await propose_brief( - session, - search_space_id=search_space_id, - focus=user_prompt, - ) - await service.attach_brief(podcast, spec) + session.add(podcast) await session.commit() + await session.refresh(podcast) podcast_id = podcast.id - logger.info( - "[generate_podcast] Prepared podcast %s awaiting brief review", - podcast_id, + from app.tasks.celery_tasks.podcast_tasks import ( + generate_content_podcast_task, ) - payload: dict[str, Any] = { - "status": PodcastStatus.AWAITING_BRIEF.value, + task = generate_content_podcast_task.delay( + podcast_id=podcast_id, + source_content=source_content, + search_space_id=search_space_id, + user_prompt=user_prompt, + ) + + logger.info( + "[generate_podcast] Created podcast %s, task: %s", + podcast_id, + task.id, + ) + + # Wait until the Celery worker flips the row to a terminal + # state. The wait is bounded only by the subagent invoke + # timeout (multi-agent) or HTTP lifetime (single-agent) — + # see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details. + terminal_status, columns, elapsed = await wait_for_deliverable( + model=Podcast, + row_id=podcast_id, + columns=[Podcast.status, Podcast.file_location], + terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, + ) + + if terminal_status == PodcastStatus.READY: + file_location = columns[1] if columns else None + logger.info( + "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", + podcast_id, + elapsed, + file_location, + ) + payload: dict[str, Any] = { + "status": PodcastStatus.READY.value, + "podcast_id": podcast_id, + "title": podcast_title, + "file_location": file_location, + "message": ("Podcast generated and saved to your podcast panel."), + } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="podcast", + operation="generate", + status="success", + external_id=str(podcast_id), + preview=podcast_title, + ), + tool_call_id=runtime.tool_call_id, + ) + + # Only other terminal state is FAILED. + logger.warning( + "[generate_podcast] Podcast %s FAILED in %.2fs", + podcast_id, + elapsed, + ) + err = "Background worker reported FAILED status for this podcast." + payload = { + "status": PodcastStatus.FAILED.value, "podcast_id": podcast_id, "title": podcast_title, - "message": ( - "Podcast set up. The card in the chat handles the rest: " - "the user reviews the brief (language, voices, length) " - "there, and the episode drafts and renders automatically " - "after approval. The card tracks progress live, so do not " - "state the podcast's current status in your reply." - ), + "error": err, } return with_receipt( payload=payload, @@ -113,9 +156,10 @@ def create_generate_podcast_tool( route="deliverables", type="podcast", operation="generate", - status="success", + status="failed", external_id=str(podcast_id), preview=podcast_title, + error=err, ), tool_call_id=runtime.tool_call_id, ) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index e7f2a0f0d..aad432edb 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -2,9 +2,9 @@ LLM configuration utilities for SurfSense agents. This module provides functions for loading LLM configurations from: -1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model +1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing 2. YAML files (global configs with negative IDs) -3. Database model-connections table (user-created configs with positive IDs) +3. Database NewLLMConfig table (user-created configs with positive IDs) It also provides utilities for creating ChatLiteLLM instances and managing prompt configurations. @@ -24,6 +24,8 @@ from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_litellm import ChatLiteLLM from litellm import get_model_info +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, @@ -31,7 +33,10 @@ from app.agents.chat.runtime.prompt_caching import ( from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, + LLMRouterService, _sanitize_content, + get_auto_mode_llm, + is_auto_mode, ) @@ -46,19 +51,16 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: reject the blank text. The OpenAI spec says ``content`` should be ``null`` when an assistant message only carries tool calls. """ - sanitized: list[BaseMessage] = [] for msg in messages: - next_msg = msg.model_copy(deep=True) - if isinstance(next_msg.content, list): - next_msg.content = _sanitize_content(next_msg.content) + if isinstance(msg.content, list): + msg.content = _sanitize_content(msg.content) if ( - isinstance(next_msg, AIMessage) - and (not next_msg.content or next_msg.content == "") - and getattr(next_msg, "tool_calls", None) + isinstance(msg, AIMessage) + and (not msg.content or msg.content == "") + and getattr(msg, "tool_calls", None) ): - next_msg.content = None # type: ignore[assignment] - sanitized.append(next_msg) - return sanitized + msg.content = None # type: ignore[assignment] + return messages class SanitizedChatLiteLLM(ChatLiteLLM): @@ -89,21 +91,13 @@ class SanitizedChatLiteLLM(ChatLiteLLM): ): yield chunk - async def _agenerate( - self, - messages: list[BaseMessage], - stop: list[str] | None = None, - run_manager: AsyncCallbackManagerForLLMRun | None = None, - stream: bool | None = None, - **kwargs: Any, - ) -> ChatResult: - return await super()._agenerate( - _sanitize_messages(messages), - stop=stop, - run_manager=run_manager, - stream=stream, - **kwargs, - ) + +# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives +# in provider_capabilities so the YAML loader can resolve prefixes during +# app.config init without importing the agent/tools tree. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -127,9 +121,8 @@ class AgentConfig: """ Complete configuration for the SurfSense agent. - This combines resolved model settings with prompt configuration. - Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to - a concrete global or BYOK model before constructing ChatLiteLLM. + This combines LLM settings with prompt configuration from NewLLMConfig. + Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. """ # LLM Model Settings @@ -177,7 +170,7 @@ class AgentConfig: use_default_system_instructions=True, citations_enabled=True, config_id=AUTO_MODE_ID, - config_name="Auto", + config_name="Auto (Fastest)", is_auto_mode=True, billing_tier="free", is_premium=False, @@ -188,21 +181,64 @@ class AgentConfig: supports_image_input=True, ) + @classmethod + def from_new_llm_config(cls, config) -> "AgentConfig": + """Build an AgentConfig from a NewLLMConfig database model.""" + # Lazy import: keeps provider_capabilities (and litellm) out of init order. + from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, + model_name=config.model_name, + api_key=config.api_key, + api_base=config.api_base, + custom_provider=config.custom_provider, + litellm_params=config.litellm_params, + system_instructions=config.system_instructions, + use_default_system_instructions=config.use_default_system_instructions, + citations_enabled=config.citations_enabled, + config_id=config.id, + config_name=config.name, + is_auto_mode=False, + billing_tier="free", + is_premium=False, + anonymous_enabled=False, + quota_reserve_tokens=None, + # BYOK rows have no curated flag; ask LiteLLM (default-allow on + # unknown). The streaming safety net still blocks explicit text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), + ) + @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": """Build an AgentConfig from a YAML configuration dictionary. - Supports prompt fields such as system_instructions, - use_default_system_instructions, and citations_enabled. + Supports the same prompt fields as NewLLMConfig (system_instructions, + use_default_system_instructions, citations_enabled). """ # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("provider") or yaml_config.get( - "litellm_provider", "" - ) + provider = yaml_config.get("provider", "").upper() model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -288,15 +324,93 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: return load_llm_config_from_yaml(llm_config_id) +async def load_new_llm_config_from_db( + session: AsyncSession, + config_id: int, +) -> "AgentConfig | None": + """Load a NewLLMConfig from the database by ID.""" + from app.db import NewLLMConfig + + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + print(f"Error: NewLLMConfig with id {config_id} not found") + return None + + return AgentConfig.from_new_llm_config(config) + except Exception as e: + print(f"Error loading NewLLMConfig from database: {e}") + return None + + +async def load_agent_llm_config_for_search_space( + session: AsyncSession, + search_space_id: int, +) -> "AgentConfig | None": + """Load the agent LLM config for a search space via its agent_llm_id. + + Positive id -> DB; negative -> YAML; None -> first global config (-1). + """ + from app.db import SearchSpace + + try: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + print(f"Error: SearchSpace with id {search_space_id} not found") + return None + + config_id = ( + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + ) + return await load_agent_config(session, config_id, search_space_id) + except Exception as e: + print(f"Error loading agent LLM config for search space {search_space_id}: {e}") + return None + + +async def load_agent_config( + session: AsyncSession, + config_id: int, + search_space_id: int | None = None, +) -> "AgentConfig | None": + """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" + if is_auto_mode(config_id): + if not LLMRouterService.is_initialized(): + print("Error: Auto mode requested but LLM Router not initialized") + return None + return AgentConfig.from_auto_mode() + + if config_id < 0: + # In-memory covers static YAML + dynamic OpenRouter configs. + from app.config import config as app_config + + for cfg in app_config.GLOBAL_LLM_CONFIGS: + if cfg.get("id") == config_id: + return AgentConfig.from_yaml_config(cfg) + yaml_config = load_llm_config_from_yaml(config_id) + if yaml_config: + return AgentConfig.from_yaml_config(yaml_config) + return None + else: + return await load_new_llm_config_from_db(session, config_id) + + def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - provider = llm_config.get("provider") or llm_config.get( - "litellm_provider", "openai" - ) - model_string = f"{provider}/{llm_config['model_name']}" + provider = llm_config.get("provider", "").upper() + provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, @@ -319,17 +433,29 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Create a ChatLiteLLM from an already resolved concrete model config.""" + """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" if agent_config.is_auto_mode: - print( - "Error: Auto mode must be resolved to a concrete model before LLM creation" - ) - return None + if not LLMRouterService.is_initialized(): + print("Error: Auto mode requested but LLM Router not initialized") + return None + try: + router_llm = get_auto_mode_llm() + if router_llm is not None: + # Universal injection points only: auto-mode fans out across + # providers, so provider-specific kwargs have no known target. + apply_litellm_prompt_caching(router_llm, agent_config=agent_config) + return router_llm + except Exception as e: + print(f"Error creating ChatLiteLLMRouter: {e}") + return None if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: - model_string = f"{agent_config.provider}/{agent_config.model_name}" + provider_prefix = PROVIDER_MAP.get( + agent_config.provider, agent_config.provider.lower() + ) + model_string = f"{provider_prefix}/{agent_config.model_name}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/agents/podcaster/__init__.py b/surfsense_backend/app/agents/podcaster/__init__.py new file mode 100644 index 000000000..8459b2977 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/__init__.py @@ -0,0 +1,8 @@ +"""New LangGraph Agent. + +This module defines a custom graph. +""" + +from .graph import graph + +__all__ = ["graph"] diff --git a/surfsense_backend/app/agents/podcaster/configuration.py b/surfsense_backend/app/agents/podcaster/configuration.py new file mode 100644 index 000000000..6a903f9df --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/configuration.py @@ -0,0 +1,29 @@ +"""Define the configurable parameters for the agent.""" + +from __future__ import annotations + +from dataclasses import dataclass, fields + +from langchain_core.runnables import RunnableConfig + + +@dataclass(kw_only=True) +class Configuration: + """The configuration for the agent.""" + + # Changeme: Add configurable values here! + # these values can be pre-set when you + # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/) + # and when you invoke the graph + podcast_title: str + search_space_id: int + user_prompt: str | None = None + + @classmethod + def from_runnable_config( + cls, config: RunnableConfig | None = None + ) -> Configuration: + """Create a Configuration instance from a RunnableConfig object.""" + configurable = (config.get("configurable") or {}) if config else {} + _fields = {f.name for f in fields(cls) if f.init} + return cls(**{k: v for k, v in configurable.items() if k in _fields}) diff --git a/surfsense_backend/app/agents/podcaster/graph.py b/surfsense_backend/app/agents/podcaster/graph.py new file mode 100644 index 000000000..94045566b --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/graph.py @@ -0,0 +1,29 @@ +from langgraph.graph import StateGraph + +from .configuration import Configuration +from .nodes import create_merged_podcast_audio, create_podcast_transcript +from .state import State + + +def build_graph(): + # Define a new graph + workflow = StateGraph(State, config_schema=Configuration) + + # Add the node to the graph + workflow.add_node("create_podcast_transcript", create_podcast_transcript) + workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio) + + # Set the entrypoint as `call_model` + workflow.add_edge("__start__", "create_podcast_transcript") + workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio") + workflow.add_edge("create_merged_podcast_audio", "__end__") + + # Compile the workflow into an executable graph + graph = workflow.compile() + graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith + + return graph + + +# Compile the graph once when the module is loaded +graph = build_graph() diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py new file mode 100644 index 000000000..d1f140a44 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -0,0 +1,195 @@ +import asyncio +import json +import os +import uuid +from pathlib import Path +from typing import Any + +from ffmpeg.asyncio import FFmpeg +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig +from litellm import aspeech + +from app.config import config as app_config +from app.services.kokoro_tts_service import get_kokoro_tts_service +from app.services.llm_service import get_agent_llm +from app.utils.content_utils import extract_text_content, strip_markdown_fences + +from .configuration import Configuration +from .prompts import get_podcast_generation_prompt +from .state import PodcastTranscriptEntry, PodcastTranscripts, State +from .utils import get_voice_for_provider + + +async def create_podcast_transcript( + state: State, config: RunnableConfig +) -> dict[str, Any]: + """Generate the podcast transcript from the source content.""" + configuration = Configuration.from_runnable_config(config) + search_space_id = configuration.search_space_id + user_prompt = configuration.user_prompt + + llm = await get_agent_llm(state.db_session, search_space_id) + if not llm: + error_message = f"No agent LLM configured for search space {search_space_id}" + print(error_message) + raise RuntimeError(error_message) + + prompt = get_podcast_generation_prompt(user_prompt) + messages = [ + SystemMessage(content=prompt), + HumanMessage( + content=f"<source_content>{state.source_content}</source_content>" + ), + ] + llm_response = await llm.ainvoke(messages) + + # Reasoning models may return content as blocks; normalise to a string. + content = strip_markdown_fences(extract_text_content(llm_response.content)) + + try: + podcast_transcript = PodcastTranscripts.model_validate(json.loads(content)) + except (json.JSONDecodeError, TypeError, ValueError) as e: + print(f"Direct JSON parsing failed, trying fallback approach: {e!s}") + + try: + json_start = content.find("{") + json_end = content.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + json_str = content[json_start:json_end] + parsed_data = json.loads(json_str) + podcast_transcript = PodcastTranscripts.model_validate(parsed_data) + print("Successfully parsed podcast transcript using fallback approach") + else: + error_message = f"Could not find valid JSON in LLM response. Raw response: {content}" + print(error_message) + raise ValueError(error_message) + + except (json.JSONDecodeError, TypeError, ValueError) as e2: + error_message = f"Error parsing LLM response (fallback also failed): {e2!s}" + print(f"Error parsing LLM response: {e2!s}") + print(f"Raw response: {content}") + raise + + return {"podcast_transcript": podcast_transcript.podcast_transcripts} + + +async def create_merged_podcast_audio( + state: State, config: RunnableConfig +) -> dict[str, Any]: + """Generate audio for each transcript and merge them into a single podcast file.""" + starting_transcript = PodcastTranscriptEntry( + speaker_id=1, dialog="Welcome to Surfsense Podcast." + ) + + transcript = state.podcast_transcript + + # transcript may be a PodcastTranscripts object or already a list. + if hasattr(transcript, "podcast_transcripts"): + transcript_entries = transcript.podcast_transcripts + else: + transcript_entries = transcript + + merged_transcript = [starting_transcript, *transcript_entries] + + temp_dir = Path("temp_audio") + temp_dir.mkdir(exist_ok=True) + + session_id = str(uuid.uuid4()) + output_path = f"podcasts/{session_id}_podcast.mp3" + os.makedirs("podcasts", exist_ok=True) + + audio_files = [] + + async def generate_speech_for_segment(segment, index): + if hasattr(segment, "speaker_id"): + speaker_id = segment.speaker_id + dialog = segment.dialog + else: + speaker_id = segment.get("speaker_id", 0) + dialog = segment.get("dialog", "") + + voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) + + if app_config.TTS_SERVICE == "local/kokoro": + filename = f"{temp_dir}/{session_id}_{index}.wav" + else: + filename = f"{temp_dir}/{session_id}_{index}.mp3" + + try: + if app_config.TTS_SERVICE == "local/kokoro": + kokoro_service = await get_kokoro_tts_service( + lang_code="a" + ) # American English + audio_path = await kokoro_service.generate_speech( + text=dialog, voice=voice, speed=1.0, output_path=filename + ) + return audio_path + else: + if app_config.TTS_SERVICE_API_BASE: + response = await aspeech( + model=app_config.TTS_SERVICE, + api_base=app_config.TTS_SERVICE_API_BASE, + api_key=app_config.TTS_SERVICE_API_KEY, + voice=voice, + input=dialog, + max_retries=2, + timeout=600, + ) + else: + response = await aspeech( + model=app_config.TTS_SERVICE, + api_key=app_config.TTS_SERVICE_API_KEY, + voice=voice, + input=dialog, + max_retries=2, + timeout=600, + ) + + with open(filename, "wb") as f: + f.write(response.content) + + return filename + except Exception as e: + print(f"Error generating speech for segment {index}: {e!s}") + raise + + tasks = [ + generate_speech_for_segment(segment, i) + for i, segment in enumerate(merged_transcript) + ] + audio_files = await asyncio.gather(*tasks) + + try: + ffmpeg = FFmpeg().option("y") + for audio_file in audio_files: + ffmpeg = ffmpeg.input(audio_file) + + filter_complex = [] + for i in range(len(audio_files)): + filter_complex.append(f"[{i}:0]") + + filter_complex_str = ( + "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]" + ) + ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) + ffmpeg = ffmpeg.output(output_path, map="[outa]") + await ffmpeg.execute() + + print(f"Successfully created podcast audio: {output_path}") + + except Exception as e: + print(f"Error merging audio files: {e!s}") + raise + finally: + for audio_file in audio_files: + try: + os.remove(audio_file) + except Exception as e: + print(f"Error removing audio file {audio_file}: {e!s}") + pass + + return { + "podcast_transcript": merged_transcript, + "final_podcast_file_path": output_path, + } diff --git a/surfsense_backend/app/agents/podcaster/prompts.py b/surfsense_backend/app/agents/podcaster/prompts.py new file mode 100644 index 000000000..efaa79788 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/prompts.py @@ -0,0 +1,122 @@ +import datetime + + +def get_podcast_generation_prompt(user_prompt: str | None = None): + return f""" +Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")} +<podcast_generation_system> +You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery. + +{ + f''' +You **MUST** strictly adhere to the following user instruction while generating the podcast script: +<user_instruction> +{user_prompt} +</user_instruction> +''' + if user_prompt + else "" + } + +<input> +- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue. +</input> + +<output_format> +A JSON object containing the podcast transcript with alternating speakers: +{{ + "podcast_transcripts": [ + {{ + "speaker_id": 0, + "dialog": "Speaker 0 dialog here" + }}, + {{ + "speaker_id": 1, + "dialog": "Speaker 1 dialog here" + }}, + {{ + "speaker_id": 0, + "dialog": "Speaker 0 dialog here" + }}, + {{ + "speaker_id": 1, + "dialog": "Speaker 1 dialog here" + }} + ] +}} +</output_format> + +<guidelines> +1. **Establish Distinct & Consistent Host Personas:** + * **Speaker 0 (Lead Host):** Drives the conversation forward, introduces segments, poses key questions derived from the source content, and often summarizes takeaways. Maintain a guiding, clear, and engaging tone. + * **Speaker 1 (Co-Host/Expert):** Offers deeper insights, provides alternative viewpoints or elaborations on the source content, asks clarifying or challenging questions, and shares relevant anecdotes or examples. Adopt a complementary tone (e.g., analytical, enthusiastic, reflective, slightly skeptical). + * **Consistency is Key:** Ensure each speaker maintains their distinct voice, vocabulary choice, sentence structure, and perspective throughout the entire script. Avoid having them sound interchangeable. Their interaction should feel like a genuine partnership. + +2. **Craft Natural & Dynamic Dialogue:** + * **Emulate Real Conversation:** Use contractions (e.g., "don't", "it's"), interjections ("Oh!", "Wow!", "Hmm"), discourse markers ("you know", "right?", "well"), and occasional natural pauses or filler words. Avoid overly formal language or complex sentence structures typical of written text. + * **Foster Interaction & Chemistry:** Write dialogue where speakers genuinely react *to each other*. They should build on points ("Exactly, and that reminds me..."), ask follow-up questions ("Could you expand on that?"), express agreement/disagreement respectfully ("That's a fair point, but have you considered...?"), and show active listening. + * **Vary Rhythm & Pace:** Mix short, punchy lines with longer, more explanatory ones. Vary sentence beginnings. Use questions to break up exposition. The rhythm should feel spontaneous, not monotonous. + * **Inject Personality & Relatability:** Allow for appropriate humor, moments of surprise or curiosity, brief personal reflections ("I actually experienced something similar..."), or relatable asides that fit the hosts' personas and the topic. Lightly reference past discussions if it enhances context ("Remember last week when we touched on...?"). + +3. **Structure for Flow and Listener Engagement:** + * **Natural Beginning:** Start with dialogue that flows naturally after an introduction (which will be added manually). Avoid redundant greetings or podcast name mentions since these will be added separately. + * **Logical Progression & Signposting:** Guide the listener through the information smoothly. Use clear transitions to link different ideas or segments ("So, now that we've covered X, let's dive into Y...", "That actually brings me to another key finding..."). Ensure topics flow logically from one to the next. + * **Meaningful Conclusion:** Summarize the key takeaways or main points discussed, reinforcing the core message derived from the source content. End with a final thought, a lingering question for the audience, or a brief teaser for what's next, providing a sense of closure. Avoid abrupt endings. + +4. **Integrate Source Content Seamlessly & Accurately:** + * **Translate, Don't Recite:** Rephrase information from the `<source_content>` into conversational language suitable for each host's persona. Avoid directly copying dense sentences or technical jargon without explanation. The goal is discussion, not narration. + * **Explain & Contextualize:** Use analogies, simple examples, storytelling, or have one host ask clarifying questions (acting as a listener surrogate) to break down complex ideas from the source. + * **Weave Information Naturally:** Integrate facts, data, or key points from the source *within* the dialogue, not as standalone, undigested blocks. Attribute information conversationally where appropriate ("The research mentioned...", "Apparently, the key factor is..."). + * **Balance Depth & Accessibility:** Ensure the conversation is informative and factually accurate based on the source content, but prioritize clear communication and engaging delivery over exhaustive technical detail. Make it understandable and interesting for a general audience. + +5. **Length & Pacing:** + * **Six-Minute Duration:** Create a transcript that, when read at a natural speaking pace, would result in approximately 6 minutes of audio. Typically, this means around 1000 words total (based on average speaking rate of 150 words per minute). + * **Concise Speaking Turns:** Keep most speaking turns relatively brief and focused. Aim for a natural back-and-forth rhythm rather than extended monologues. + * **Essential Content Only:** Prioritize the most important information from the source content. Focus on quality over quantity, ensuring every line contributes meaningfully to the topic. +</guidelines> + +<examples> +Input: "Quantum computing uses quantum bits or qubits which can exist in multiple states simultaneously due to superposition." + +Output: +{{ + "podcast_transcripts": [ + {{ + "speaker_id": 0, + "dialog": "Today we're diving into the mind-bending world of quantum computing. You know, this is a topic I've been excited to cover for weeks." + }}, + {{ + "speaker_id": 1, + "dialog": "Same here! And I know our listeners have been asking for it. But I have to admit, the concept of quantum computing makes my head spin a little. Can we start with the basics?" + }}, + {{ + "speaker_id": 0, + "dialog": "Absolutely. So regular computers use bits, right? Little on-off switches that are either 1 or 0. But quantum computers use something called qubits, and this is where it gets fascinating." + }}, + {{ + "speaker_id": 1, + "dialog": "Wait, what makes qubits so special compared to regular bits?" + }}, + {{ + "speaker_id": 0, + "dialog": "The magic is in something called superposition. These qubits can exist in multiple states at the same time, not just 1 or 0." + }}, + {{ + "speaker_id": 1, + "dialog": "That sounds impossible! How would you even picture that?" + }}, + {{ + "speaker_id": 0, + "dialog": "Think of it like a coin spinning in the air. Before it lands, is it heads or tails?" + }}, + {{ + "speaker_id": 1, + "dialog": "Well, it's... neither? Or I guess both, until it lands? Oh, I think I see where you're going with this." + }} + ] +}} +</examples> + +Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration. +</podcast_generation_system> +""" diff --git a/surfsense_backend/app/agents/podcaster/state.py b/surfsense_backend/app/agents/podcaster/state.py new file mode 100644 index 000000000..62eb0537b --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/state.py @@ -0,0 +1,43 @@ +"""Define the state structures for the agent.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + + +class PodcastTranscriptEntry(BaseModel): + """ + Represents a single entry in a podcast transcript. + """ + + speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)") + dialog: str = Field(..., description="The dialog text spoken by the speaker") + + +class PodcastTranscripts(BaseModel): + """ + Represents the full podcast transcript structure. + """ + + podcast_transcripts: list[PodcastTranscriptEntry] = Field( + ..., description="List of transcript entries with alternating speakers" + ) + + +@dataclass +class State: + """Defines the input state for the agent, representing a narrower interface to the outside world. + + This class is used to define the initial state and structure of incoming data. + See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state + for more information. + """ + + # Runtime context + db_session: AsyncSession + source_content: str + podcast_transcript: list[PodcastTranscriptEntry] | None = None + final_podcast_file_path: str | None = None diff --git a/surfsense_backend/app/agents/podcaster/utils.py b/surfsense_backend/app/agents/podcaster/utils.py new file mode 100644 index 000000000..96ea1d51e --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/utils.py @@ -0,0 +1,84 @@ +def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str: + """ + Get the appropriate voice configuration based on the TTS provider and speaker ID. + + Args: + provider: The TTS provider (e.g., "openai/tts-1", "vertex_ai/test") + speaker_id: The ID of the speaker (0-5) + + Returns: + Voice configuration - string for OpenAI, dict for Vertex AI + """ + if provider == "local/kokoro": + # Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices + kokoro_voices = { + 0: "am_adam", # Default/intro voice + 1: "af_bella", # First speaker + } + return kokoro_voices.get(speaker_id, "af_heart") + + # Extract provider type from the model string + provider_type = ( + provider.split("/")[0].lower() if "/" in provider else provider.lower() + ) + + if provider_type == "openai": + # OpenAI voice mapping - simple string values + openai_voices = { + 0: "alloy", # Default/intro voice + 1: "echo", # First speaker + 2: "fable", # Second speaker + 3: "onyx", # Third speaker + 4: "nova", # Fourth speaker + 5: "shimmer", # Fifth speaker + } + return openai_voices.get(speaker_id, "alloy") + + elif provider_type == "vertex_ai": + # Vertex AI voice mapping - dict with languageCode and name + vertex_voices = { + 0: { + "languageCode": "en-US", + "name": "en-US-Studio-O", + }, + 1: { + "languageCode": "en-US", + "name": "en-US-Studio-M", + }, + 2: { + "languageCode": "en-UK", + "name": "en-UK-Studio-A", + }, + 3: { + "languageCode": "en-UK", + "name": "en-UK-Studio-B", + }, + 4: { + "languageCode": "en-AU", + "name": "en-AU-Studio-A", + }, + 5: { + "languageCode": "en-AU", + "name": "en-AU-Studio-B", + }, + } + return vertex_voices.get(speaker_id, vertex_voices[0]) + elif provider_type == "azure": + # OpenAI voice mapping - simple string values + azure_voices = { + 0: "alloy", # Default/intro voice + 1: "echo", # First speaker + 2: "fable", # Second speaker + 3: "onyx", # Third speaker + 4: "nova", # Fourth speaker + 5: "shimmer", # Fifth speaker + } + return azure_voices.get(speaker_id, "alloy") + + else: + # Default fallback to OpenAI format for unknown providers + default_voices = { + 0: {}, + 1: {}, + } + return default_voices.get(speaker_id, default_voices[0]) diff --git a/surfsense_backend/app/agents/video_presentation/__init__.py b/surfsense_backend/app/agents/video_presentation/__init__.py index 8a51eb0ef..caf885218 100644 --- a/surfsense_backend/app/agents/video_presentation/__init__.py +++ b/surfsense_backend/app/agents/video_presentation/__init__.py @@ -1,7 +1,8 @@ """Video Presentation LangGraph Agent. -This module defines a graph for generating slide-based video presentations -from source content, with TTS narration per slide. +This module defines a graph for generating video presentations +from source content, similar to the podcaster agent but producing +slide-based video presentations with TTS narration. """ from .graph import graph diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 6dfe6a776..d3f5dce2a 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -33,6 +33,7 @@ from app.config import ( initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, + initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError @@ -621,6 +622,7 @@ async def lifespan(app: FastAPI): initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() + initialize_vision_llm_router() # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. ``shield`` so Uvicorn cancelling startup diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py index c9584ae2a..4ef8c52bf 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -39,31 +39,31 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, - chat_model_id: int | None = None, - image_gen_model_id: int | None = None, - vision_model_id: int | None = None, + agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Resolves the chat model from the automation's *captured* model snapshot - (``chat_model_id``) so runs are insulated from later chat/search-space model + Resolves the agent LLM from the automation's *captured* model snapshot + (``agent_llm_id``) so runs are insulated from later chat/search-space model changes. The model policy is enforced here as a runtime backstop: a captured model that is no longer billable (e.g. a premium global config was removed) fails the run clearly instead of silently consuming a free model. - When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback), - fall back to the live search space's ``chat_model_id`` and validate that. + When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``agent_llm_id`` and validate that. """ - if chat_model_id is not None: + if agent_llm_id is not None: try: assert_models_billable( - chat_model_id=chat_model_id, - image_gen_model_id=image_gen_model_id, - vision_model_id=vision_model_id, + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, ) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_chat_model_id = chat_model_id or 0 + resolved_agent_llm_id = agent_llm_id or 0 else: search_space = await session.get(SearchSpace, search_space_id) if search_space is None: @@ -72,15 +72,15 @@ async def build_dependencies( assert_automation_models_billable(search_space) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_chat_model_id = search_space.chat_model_id or 0 + resolved_agent_llm_id = search_space.agent_llm_id or 0 llm, agent_config, err = await load_llm_bundle( session, - config_id=resolved_chat_model_id, + config_id=resolved_agent_llm_id, search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load chat model config") + raise DependencyError(err or "failed to load agent LLM config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index c3a35930d..aa96e4f6e 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -150,9 +150,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, - chat_model_id=ctx.chat_model_id, - image_gen_model_id=ctx.image_gen_model_id, - vision_model_id=ctx.vision_model_id, + agent_llm_id=ctx.agent_llm_id, + image_generation_config_id=ctx.image_generation_config_id, + vision_llm_config_id=ctx.vision_llm_config_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -167,7 +167,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, - image_gen_model_id=ctx.image_gen_model_id, + image_generation_config_id=ctx.image_generation_config_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 3ee427512..453721a43 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -23,9 +23,9 @@ class ActionContext: # Captured model snapshot from the automation definition (``definition.models``), # resolved per run instead of the live search space. ``None`` falls back to the # search space's current prefs (defensive; should not happen post-capture). - chat_model_id: int | None = None - image_gen_model_id: int | None = None - vision_model_id: int | None = None + agent_llm_id: int | None = None + image_generation_config_id: int | None = None + vision_llm_config_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index bcdab3940..da249d8e5 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -132,7 +132,9 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, - chat_model_id=models.chat_model_id if models else None, - image_gen_model_id=models.image_gen_model_id if models else None, - vision_model_id=models.vision_model_id if models else None, + agent_llm_id=models.agent_llm_id if models else None, + image_generation_config_id=( + models.image_generation_config_id if models else None + ), + vision_llm_config_id=models.vision_llm_config_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index 787534d4a..7ca55b1ce 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec class AutomationModels(BaseModel): """Captured model profile for an automation. - Snapshotted from the search space's model roles at create time so runs are - insulated from later chat/search-space model changes. Model-id conventions + Snapshotted from the search space's preferences at create time so runs are + insulated from later chat/search-space model changes. Config-id conventions match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). """ model_config = ConfigDict(extra="forbid") - chat_model_id: int = 0 - image_gen_model_id: int = 0 - vision_model_id: int = 0 + agent_llm_id: int = 0 + image_generation_config_id: int = 0 + vision_llm_config_id: int = 0 class AutomationDefinition(BaseModel): diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 1d371c35d..4227161e2 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -57,9 +57,9 @@ class AutomationService: else: search_space = await self._assert_models_billable(payload.search_space_id) payload.definition.models = AutomationModels( - chat_model_id=search_space.chat_model_id or 0, - image_gen_model_id=search_space.image_gen_model_id or 0, - vision_model_id=search_space.vision_model_id or 0, + agent_llm_id=search_space.agent_llm_id or 0, + image_generation_config_id=search_space.image_generation_config_id or 0, + vision_llm_config_id=search_space.vision_llm_config_id or 0, ) automation = Automation( @@ -225,9 +225,9 @@ class AutomationService: """ try: assert_models_billable( - chat_model_id=models.chat_model_id, - image_gen_model_id=models.image_gen_model_id, - vision_model_id=models.vision_model_id, + agent_llm_id=models.agent_llm_id, + image_generation_config_id=models.image_generation_config_id, + vision_llm_config_id=models.vision_llm_config_id, ) except AutomationModelPolicyError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index e18264246..7e3e46b61 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -2,11 +2,11 @@ Automations run unattended, so every run must be **billable**: it may only use either a premium global model (``billing_tier == "premium"``) or a user-provided -BYOK model (a positive model id pointing at a per-user/per-space DB row). Free +BYOK model (a positive config id pointing at a per-user/per-space DB row). Free global models and Auto mode are blocked, because Auto can dispatch to a free deployment and free models aren't metered in premium credits. -Model id conventions (shared across chat / image / vision): +Config id conventions (shared across chat / image / vision): - ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / ``VISION_AUTO_MODE_ID``). Blocked. - ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. @@ -24,45 +24,70 @@ from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from app.db import SearchSpace -ModelKind = Literal["chat", "image", "vision"] +ModelKind = Literal["llm", "image", "vision"] _KIND_LABEL: dict[ModelKind, str] = { - "chat": "chat model", + "llm": "agent LLM", "image": "image generation model", "vision": "vision model", } -def _is_premium_global(model_id: int) -> bool: - """Return True if a negative (global) model id is a premium tier model.""" +def _is_premium_global(kind: ModelKind, config_id: int) -> bool: + """Return True if a negative (global) config id is a premium tier model.""" from app.config import config as app_config - model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None) - if not model: + cfg: dict | None = None + if kind == "llm": + from app.agents.chat.runtime.llm_config import ( + load_global_llm_config_by_id, + ) + + cfg = load_global_llm_config_by_id(config_id) + elif kind == "image": + cfg = next( + ( + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if c.get("id") == config_id + ), + None, + ) + else: # vision + cfg = next( + ( + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if c.get("id") == config_id + ), + None, + ) + + if not cfg: return False - return str(model.get("billing_tier", "free")).lower() == "premium" + return str(cfg.get("billing_tier", "free")).lower() == "premium" -def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: - """Classify a resolved model id as allowed or blocked. +def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: + """Classify a resolved config id as allowed or blocked. Returns ``(allowed, reason)``; ``reason`` is empty when allowed. """ label = _KIND_LABEL[kind] - if model_id is None or model_id == 0: + if config_id is None or config_id == 0: return ( False, f"The {label} is set to Auto mode. Automations require an explicit " "premium model or your own (BYOK) model so every run is billable.", ) - if model_id > 0: - # Positive id -> user/search-space BYOK model. Always allowed. + if config_id > 0: + # Positive id → user-owned BYOK config. Always allowed. return True, "" - # Negative id -> global model. Allowed only if premium. - if _is_premium_global(model_id): + # Negative id → global config. Allowed only if premium. + if _is_premium_global(kind, config_id): return True, "" return ( @@ -74,27 +99,27 @@ def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: def get_model_eligibility( *, - chat_model_id: int | None, - image_gen_model_id: int | None, - vision_model_id: int | None, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, ) -> dict: - """Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids. + """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is - ``{"kind", "model_id", "reason"}``. + ``{"kind", "config_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ - ("chat", chat_model_id), - ("image", image_gen_model_id), - ("vision", vision_model_id), + ("llm", agent_llm_id), + ("image", image_generation_config_id), + ("vision", vision_llm_config_id), ] violations: list[dict] = [] - for kind, model_id in checks: - allowed, reason = _classify(kind, model_id) + for kind, config_id in checks: + allowed, reason = _classify(kind, config_id) if not allowed: - violations.append({"kind": kind, "model_id": model_id, "reason": reason}) + violations.append({"kind": kind, "config_id": config_id, "reason": reason}) return {"allowed": not violations, "violations": violations} @@ -106,9 +131,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict: wrapper over :func:`get_model_eligibility`. """ return get_model_eligibility( - chat_model_id=search_space.chat_model_id, - image_gen_model_id=search_space.image_gen_model_id, - vision_model_id=search_space.vision_model_id, + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, ) @@ -125,9 +150,9 @@ class AutomationModelPolicyError(Exception): def assert_models_billable( *, - chat_model_id: int | None, - image_gen_model_id: int | None, - vision_model_id: int | None, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, ) -> None: """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. @@ -135,9 +160,9 @@ def assert_models_billable( captured model snapshot. """ result = get_model_eligibility( - chat_model_id=chat_model_id, - image_gen_model_id=image_gen_model_id, - vision_model_id=vision_model_id, + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, ) if not result["allowed"]: raise AutomationModelPolicyError(result["violations"]) diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 704c9cf9b..0e852b801 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -115,12 +115,14 @@ def init_worker(**kwargs): initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, + initialize_vision_llm_router, ) initialize_openrouter_integration() initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() + initialize_vision_llm_router() # Celery configuration, sourced from the central Config singleton @@ -179,8 +181,7 @@ celery_app = Celery( backend=CELERY_RESULT_BACKEND, include=[ "app.tasks.celery_tasks.document_tasks", - "app.podcasts.tasks.draft", - "app.podcasts.tasks.render", + "app.tasks.celery_tasks.podcast_tasks", "app.tasks.celery_tasks.video_presentation_tasks", "app.tasks.celery_tasks.connector_tasks", "app.tasks.celery_tasks.obsidian_tasks", @@ -188,10 +189,7 @@ celery_app = Celery( "app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.stale_notification_cleanup_task", "app.tasks.celery_tasks.stripe_reconciliation_task", - "app.tasks.celery_tasks.auto_reload_task", "app.tasks.celery_tasks.gateway_tasks", - "app.etl_pipeline.cache.eviction.task", - "app.indexing_pipeline.cache.eviction.task", "app.automations.tasks.execute_run", "app.automations.triggers.builtin.schedule.selector", "app.automations.triggers.builtin.event.selector", @@ -283,9 +281,16 @@ celery_app.conf.beat_schedule = { "expires": 60, # Task expires after 60 seconds if not picked up }, }, - # Reconcile Stripe credit purchases that were paid but remained pending - "reconcile-pending-stripe-credit-purchases": { - "task": "reconcile_pending_stripe_credit_purchases", + # Reconcile Stripe purchases that were paid but remained pending + "reconcile-pending-stripe-page-purchases": { + "task": "reconcile_pending_stripe_page_purchases", + "schedule": crontab(**stripe_reconciliation_schedule_params), + "options": { + "expires": 60, + }, + }, + "reconcile-pending-stripe-token-purchases": { + "task": "reconcile_pending_stripe_token_purchases", "schedule": crontab(**stripe_reconciliation_schedule_params), "options": { "expires": 60, @@ -306,18 +311,6 @@ celery_app.conf.beat_schedule = { "schedule": crontab(hour="3", minute="17"), "options": {"expires": 600}, }, - # Prune the ETL parse cache (TTL + size budget) once daily, off-peak. - "evict-etl-cache": { - "task": "evict_etl_cache", - "schedule": crontab(hour="4", minute="0"), - "options": {"expires": 600}, - }, - # Prune the embedding cache (chunk+embedding sets) once daily, off-peak. - "evict-embedding-cache": { - "task": "evict_embedding_cache", - "schedule": crontab(hour="4", minute="30"), - "options": {"expires": 600}, - }, # Fire due automation schedule triggers (Beat entry owned by the schedule # trigger; see app.automations.triggers.builtin.schedule.source). **SCHEDULE_BEAT_SCHEDULE, diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 63be54654..75af17d11 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -78,7 +78,8 @@ def load_global_llm_configs(): # stamps) never leak into the cached YAML structure. configs = copy.deepcopy(data.get("global_llm_configs", [])) - # Lazy import keeps the `app.config` -> `app.services` edge one-way. + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. from app.services.provider_capabilities import derive_supports_image_input seen_slugs: dict[str, int] = {} @@ -103,7 +104,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -119,10 +120,10 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) - # Stamp Auto ranking metadata. YAML configs are always + # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose provider == "openrouter" via _enrich_health. + # whose provider == "OPENROUTER" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -132,7 +133,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose provider is openrouter are also subject + # YAML cfgs whose provider is OPENROUTER are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. @@ -210,6 +211,42 @@ def load_global_image_gen_configs(): return [] +def load_global_vision_llm_configs(): + data = _global_config_data() + if not data: + return [] + + try: + configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or []) + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs + except Exception as e: + print(f"Warning: Failed to load global vision LLM configs: {e}") + return [] + + +def load_vision_llm_router_settings(): + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + } + + data = _global_config_data() + if not data: + return default_settings + + try: + settings = data.get("vision_llm_router_settings", {}) + return {**default_settings, **settings} + except Exception as e: + print(f"Warning: Failed to load vision LLM router settings: {e}") + return default_settings + + def load_image_gen_router_settings(): """ Load router settings for image generation Auto mode from YAML file. @@ -326,8 +363,8 @@ def initialize_openrouter_integration(): else: print("Info: OpenRouter integration enabled but no models fetched") - # Image generation emissions reuse the catalogue already cached by - # ``service.initialize`` + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` # so we don't make additional network calls here. if settings.get("image_generation_enabled"): try: @@ -341,26 +378,21 @@ def initialize_openrouter_integration(): except Exception as e: print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") - refresh_global_model_catalog() + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") -def materialize_global_configs(): - from app.services.global_model_catalog import materialize_global_model_catalog - - return materialize_global_model_catalog( - chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), - image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), - ) - - -def refresh_global_model_catalog(): - connections, models = materialize_global_configs() - config.GLOBAL_CONNECTIONS = connections - config.GLOBAL_MODELS = models - - def initialize_pricing_registration(): """ Teach LiteLLM the per-token cost of every deployment in @@ -398,10 +430,7 @@ def initialize_llm_router(): router_settings = config.ROUTER_SETTINGS if not all_configs: - print( - "Info: No global LLM configs found; global Auto pool is unavailable. " - "Auto can still use enabled BYOK models." - ) + print("Info: No global LLM configs found, Auto mode will not be available") return try: @@ -446,6 +475,32 @@ def initialize_image_gen_router(): print(f"Warning: Failed to initialize Image Generation Router: {e}") +def initialize_vision_llm_router(): + vision_configs = load_global_vision_llm_configs() + # Reuse the router settings already parsed at Config construction. The + # *configs* list is intentionally re-read from YAML (it must exclude the + # OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS). + router_settings = config.VISION_LLM_ROUTER_SETTINGS + + if not vision_configs: + print( + "Info: No global vision LLM configs found, " + "Vision LLM Auto mode will not be available" + ) + return + + try: + from app.services.vision_llm_router_service import VisionLLMRouterService + + VisionLLMRouterService.initialize(vision_configs, router_settings) + print( + f"Info: Vision LLM Router initialized with {len(vision_configs)} models " + f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" + ) + except Exception as e: + print(f"Warning: Failed to initialize Vision LLM Router: {e}") + + class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -486,28 +541,6 @@ class Config: # Database DATABASE_URL = os.getenv("DATABASE_URL") - # When TRUE (default) the app ensures extensions/tables/indexes exist on - # startup. Set FALSE in environments where schema is owned exclusively by - # Alembic migrations to skip all boot-time DDL. - DB_BOOTSTRAP_ON_STARTUP = ( - os.getenv("DB_BOOTSTRAP_ON_STARTUP", "TRUE").upper() == "TRUE" - ) - # Per-session lock_timeout (ms) applied to boot-time DDL so a contended - # CREATE INDEX / CREATE TABLE fails fast instead of hanging the FastAPI - # lifespan forever behind another transaction's lock. - DB_DDL_LOCK_TIMEOUT_MS = int(os.getenv("DB_DDL_LOCK_TIMEOUT_MS", "5000")) - # Global idle_in_transaction_session_timeout (ms) applied to every pooled - # connection so an abandoned "idle in transaction" session can't wedge the - # database indefinitely. 0 disables. Only applied to asyncpg connections. - DB_IDLE_IN_TX_TIMEOUT_MS = int(os.getenv("DB_IDLE_IN_TX_TIMEOUT_MS", "900000")) - # Same protection for the separate Celery worker engine, where long-running - # ingestion/podcast/video tasks live. Kept higher than the web default so a - # legitimate per-document embed window is never reaped: if a task hasn't - # touched the DB in 60 min it's treated as orphaned and dropped. 0 disables. - DB_CELERY_IDLE_IN_TX_TIMEOUT_MS = int( - os.getenv("DB_CELERY_IDLE_IN_TX_TIMEOUT_MS", "3600000") - ) - # Celery / Redis # Redis (single endpoint for Celery broker, result backend, and app cache). # Legacy CELERY_BROKER_URL / CELERY_RESULT_BACKEND / REDIS_APP_URL still @@ -557,15 +590,14 @@ class Config: # Platform web search (SearXNG) SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST") - SURFSENSE_PUBLIC_URL = os.getenv("SURFSENSE_PUBLIC_URL") - NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") or SURFSENSE_PUBLIC_URL + NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") # Backend URL to override the http to https in the OAuth redirect URI - BACKEND_URL = os.getenv("BACKEND_URL") or SURFSENSE_PUBLIC_URL + BACKEND_URL = os.getenv("BACKEND_URL") - # Messaging gateway + # Messaging gateway (Telegram v1) # Global master switch: when FALSE, no gateway supervisors/workers start and all - # gated gateway HTTP routes return 404, regardless of the per-channel flags below. - GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "FALSE").upper() == "TRUE" + # gateway HTTP routes return 404, regardless of the per-channel flags below. + GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "TRUE").upper() == "TRUE" TELEGRAM_SHARED_BOT_TOKEN = os.getenv("TELEGRAM_SHARED_BOT_TOKEN") TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME") TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET") @@ -608,9 +640,14 @@ class Config: ) GATEWAY_DISCORD_REDIRECT_URI = os.getenv("GATEWAY_DISCORD_REDIRECT_URI") - # Stripe checkout (shared secrets for the unified credit wallet) + # Stripe checkout for pay-as-you-go page packs STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY") STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET") + STRIPE_PRICE_ID = os.getenv("STRIPE_PRICE_ID") + STRIPE_PAGES_PER_UNIT = int(os.getenv("STRIPE_PAGES_PER_UNIT", "1000")) + STRIPE_PAGE_BUYING_ENABLED = ( + os.getenv("STRIPE_PAGE_BUYING_ENABLED", "TRUE").upper() == "TRUE" + ) STRIPE_RECONCILIATION_LOOKBACK_MINUTES = int( os.getenv("STRIPE_RECONCILIATION_LOOKBACK_MINUTES", "10") ) @@ -618,56 +655,27 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Unified credit wallet (micro-USD) settings. + # Premium credit (micro-USD) quota settings. # - # Storage unit is integer micro-USD (1_000_000 = $1.00). A single - # ``credit_micros_balance`` funds both ETL page processing and premium - # model calls. New users start with ``DEFAULT_CREDIT_MICROS_BALANCE`` - # ($5 by default). - # - # Legacy env names (``PREMIUM_CREDIT_MICROS_LIMIT`` / ``PREMIUM_TOKEN_LIMIT``, - # ``STRIPE_PREMIUM_TOKEN_PRICE_ID``, ``STRIPE_CREDIT_MICROS_PER_UNIT`` / - # ``STRIPE_TOKENS_PER_UNIT``, ``STRIPE_TOKEN_BUYING_ENABLED``) are still - # honoured as fall-backs for one release; deprecation warnings fire below. - DEFAULT_CREDIT_MICROS_BALANCE = int( - os.getenv("DEFAULT_CREDIT_MICROS_BALANCE") - or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") ) - STRIPE_CREDIT_PRICE_ID = os.getenv("STRIPE_CREDIT_PRICE_ID") or os.getenv( - "STRIPE_PREMIUM_TOKEN_PRICE_ID" - ) + STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") STRIPE_CREDIT_MICROS_PER_UNIT = int( os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") ) - STRIPE_CREDIT_BUYING_ENABLED = ( - os.getenv("STRIPE_CREDIT_BUYING_ENABLED") - or os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE") - ).upper() == "TRUE" - - # ETL page processing debits the credit wallet only when enabled. Defaults - # to FALSE so self-hosted / OSS installs keep effectively-free ETL; hosted - # deployments set this TRUE. 1 page == ``MICROS_PER_PAGE`` micro-USD. - ETL_CREDIT_BILLING_ENABLED = ( - os.getenv("ETL_CREDIT_BILLING_ENABLED", "FALSE").upper() == "TRUE" + STRIPE_TOKEN_BUYING_ENABLED = ( + os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) - MICROS_PER_PAGE = int(os.getenv("MICROS_PER_PAGE", "1000")) - - # Low-balance WARNING threshold (micro-USD). Surfaced by the quota service - # so the UI can nudge the user to top up / enable auto-reload. $0.50. - CREDIT_LOW_BALANCE_WARNING_MICROS = int( - os.getenv("CREDIT_LOW_BALANCE_WARNING_MICROS", "500000") - ) - - # Auto-reload (off-session Stripe top-up) feature flag and guards. - AUTO_RELOAD_ENABLED = os.getenv("AUTO_RELOAD_ENABLED", "FALSE").upper() == "TRUE" - # Minimum configurable reload amount (micro-USD). $1.00 to match pack pricing. - AUTO_RELOAD_MIN_AMOUNT_MICROS = int( - os.getenv("AUTO_RELOAD_MIN_AMOUNT_MICROS", "1000000") - ) - # Cooldown so a burst of debits can't fire multiple charges (minutes). - AUTO_RELOAD_COOLDOWN_MINUTES = int(os.getenv("AUTO_RELOAD_COOLDOWN_MINUTES", "10")) # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` # estimates an upper-bound cost from ``litellm.get_model_info`` x the @@ -677,13 +685,14 @@ class Config: # reserve_tokens ≈ $0.36) with headroom. QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) - if ( - os.getenv("PREMIUM_TOKEN_LIMIT") or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") - ) and not os.getenv("DEFAULT_CREDIT_MICROS_BALANCE"): + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): print( - "Warning: PREMIUM_TOKEN_LIMIT / PREMIUM_CREDIT_MICROS_LIMIT are " - "deprecated; rename to DEFAULT_CREDIT_MICROS_BALANCE. The old keys " - "will be removed in a future release." + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." ) if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( "STRIPE_CREDIT_MICROS_PER_UNIT" @@ -693,22 +702,6 @@ class Config: "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " "The old key will be removed in a future release." ) - if os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") and not os.getenv( - "STRIPE_CREDIT_PRICE_ID" - ): - print( - "Warning: STRIPE_PREMIUM_TOKEN_PRICE_ID is deprecated; rename to " - "STRIPE_CREDIT_PRICE_ID. The old key will be removed in a future " - "release." - ) - if os.getenv("STRIPE_TOKEN_BUYING_ENABLED") and not os.getenv( - "STRIPE_CREDIT_BUYING_ENABLED" - ): - print( - "Warning: STRIPE_TOKEN_BUYING_ENABLED is deprecated; rename to " - "STRIPE_CREDIT_BUYING_ENABLED. The old key will be removed in a " - "future release." - ) # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" @@ -730,7 +723,7 @@ class Config: os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") ) - # Per-podcast reservation (in micro-USD). One chat model call generating + # Per-podcast reservation (in micro-USD). One agent LLM call generating # a transcript, typically 5k-20k completion tokens. $0.20 covers a long # premium-model run. Tune via env. QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( @@ -836,13 +829,6 @@ class Config: # LLM instances are now managed per-user through the LLMConfig system # Legacy environment variables removed in favor of user-specific configurations - # True when an operator-provided global_llm_config.yaml is present. - # Used to gate the per-search-space LLM onboarding flow: when a global - # config file exists, search spaces inherit it and onboarding is skipped. - GLOBAL_LLM_CONFIG_FILE_EXISTS = ( - BASE_DIR / "app" / "config" / "global_llm_config.yaml" - ).exists() - # Global LLM Configurations (optional) # Load from global_llm_config.yaml if available # These can be used as default options for users @@ -857,17 +843,11 @@ class Config: # Router settings for Image Generation Auto mode IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() - # Virtual GLOBAL connection/model catalog. This is server-only metadata - # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. - from app.services.global_model_catalog import ( - materialize_global_model_catalog as _materialize_global_model_catalog, - ) + # Global Vision LLM Configurations (optional) + GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs() - GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( - chat_configs=GLOBAL_LLM_CONFIGS, - image_configs=GLOBAL_IMAGE_GEN_CONFIGS, - ) - del _materialize_global_model_catalog + # Router settings for Vision LLM Auto mode + VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() # OpenRouter Integration settings (optional) OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings() @@ -923,6 +903,9 @@ class Config: # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") + # Pages limit for ETL services (default to very high number for OSS unlimited usage) + PAGES_LIMIT = int(os.getenv("PAGES_LIMIT", "999999999")) + if ETL_SERVICE == "UNSTRUCTURED": # Unstructured API Key UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY") @@ -933,47 +916,6 @@ class Config: AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT") AZURE_DI_KEY = os.getenv("AZURE_DI_KEY") - # ETL parse cache: reuse parser output for identical bytes across workspaces. - ETL_CACHE_ENABLED = ( - os.getenv("ETL_CACHE_ENABLED", "false").strip().lower() == "true" - ) - # Bump to invalidate every cached entry after a parser/behaviour change. - ETL_CACHE_PARSER_VERSION = int(os.getenv("ETL_CACHE_PARSER_VERSION", "1")) - ETL_CACHE_TTL_DAYS = int(os.getenv("ETL_CACHE_TTL_DAYS", "90")) - ETL_CACHE_MAX_TOTAL_MB = int(os.getenv("ETL_CACHE_MAX_TOTAL_MB", "5120")) - ETL_CACHE_EVICTION_BATCH = int(os.getenv("ETL_CACHE_EVICTION_BATCH", "500")) - # Optional dedicated blob storage; unset reuses the main file_storage backend. - ETL_CACHE_STORAGE_BACKEND = os.getenv("ETL_CACHE_STORAGE_BACKEND") - ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER") - ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH") - - # Embedding cache: reuse chunk+embedding output for identical markdown across - # workspaces. Blobs share the ETL_CACHE_STORAGE_* backend. - EMBEDDING_CACHE_ENABLED = ( - os.getenv("EMBEDDING_CACHE_ENABLED", "false").strip().lower() == "true" - ) - # Bump to invalidate every cached embedding set after a chunker change. - EMBEDDING_CACHE_CHUNKER_VERSION = int( - os.getenv("EMBEDDING_CACHE_CHUNKER_VERSION", "1") - ) - EMBEDDING_CACHE_TTL_DAYS = int(os.getenv("EMBEDDING_CACHE_TTL_DAYS", "90")) - EMBEDDING_CACHE_MAX_TOTAL_MB = int( - os.getenv("EMBEDDING_CACHE_MAX_TOTAL_MB", "5120") - ) - EMBEDDING_CACHE_EVICTION_BATCH = int( - os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500") - ) - - # Incremental re-indexing: on document edits, keep chunk rows whose text is - # unchanged (reusing their embeddings) and embed only new/changed chunks. - # Kill switch -- disabling falls back to delete-all + full re-embed. - CHUNK_RECONCILE_ENABLED = ( - os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true" - ) - INDEXING_CHUNK_INSERT_BATCH_SIZE = int( - os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200") - ) - # Proxy provider selection. Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 329d96e37..1c09a91ac 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -1,236 +1,362 @@ # Global LLM Configuration # # SETUP INSTRUCTIONS: -# 1. Copy this file to global_llm_config.yaml. -# 2. Replace placeholder credentials, endpoints, deployment names, and pricing -# with values from your own provider accounts. +# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys +# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist # -# This file is intentionally safe to commit. Do not put real API keys in this -# example file. +# NOTE: The example API keys below are placeholders and won't work. +# Replace them with your actual API keys to enable global configurations. # -# These YAML entries are materialized at startup as server-owned GLOBAL -# connections and models: +# These configurations will be available to all users as a convenient option +# Users can choose to use these global configs or add their own # -# global_llm_configs -> GLOBAL chat models -# global_image_generation_configs -> GLOBAL image generation models +# AUTO MODE (Recommended): +# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs +# - This helps avoid rate limits by distributing requests across multiple providers +# - New users are automatically assigned Auto mode by default +# - Configure router_settings below to customize the load balancing behavior # -# Do not add global_connections or global_models sections here. They are -# runtime-derived metadata exposed through the model-connections APIs. -# -# Static config shape: -# - Connection fields: provider, api_key, api_base, api_version -# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params -# - Public no-login SEO metadata: seo_title, seo_description -# - Prompt defaults: system_instructions, use_default_system_instructions, -# citations_enabled -# -# Provider notes: -# - Use the canonical provider field. -# - For Azure, use the bare deployment name in model_name, for example -# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from -# provider: "azure". -# -# GLOBAL ID namespace: -# - ID 0 is reserved for Auto mode. -# - Negative IDs are server-owned GLOBAL models. -# - Positive IDs are user/BYOK database models. -# - Keep static IDs unique across chat and image generation. -# - Suggested static ranges: chat -1..-999, image -2001..-2999. -# - Vision is not a separate config/table. Chat models that accept images use -# supports_image_input: true. +# Structure matches NewLLMConfig: +# - Model configuration (provider, model_name, api_key, etc.) +# - Prompt configuration (system_instructions, citations_enabled) # # COST-BASED PREMIUM CREDITS: -# Each premium model bills the user's USD-credit balance based on provider cost -# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does -# not know, declare per-token costs inline: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: # # litellm_params: -# base_model: "my-custom-deployment" -# # USD per token; 0.00000125 == $1.25 per million input tokens. -# input_cost_per_token: 0.00000125 -# output_cost_per_token: 0.00001 +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 # -# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's -# API. Models without resolvable pricing debit $0 and log a warning. +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. -# ============================================================================= -# Chat Auto Mode Router Settings -# ============================================================================= -# These settings control how the LiteLLM Router distributes Auto-mode requests -# across curated router-eligible GLOBAL chat deployments. +# Router Settings for Auto Mode +# These settings control how the LiteLLM Router distributes requests across models router_settings: # Routing strategy options: - # - "usage-based-routing": Routes to deployment with lowest current usage. - # - "simple-shuffle": Random distribution with optional RPM/TPM weighting. - # - "least-busy": Routes to least busy deployment. - # - "latency-based-routing": Routes based on response latency. + # - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits) + # - "simple-shuffle": Random distribution with optional RPM/TPM weighting + # - "least-busy": Routes to least busy deployment + # - "latency-based-routing": Routes based on response latency routing_strategy: "usage-based-routing" + + # Number of retries before failing num_retries: 3 + + # Number of failures allowed before cooling down a deployment allowed_fails: 3 + + # Cooldown time in seconds after allowed_fails is exceeded cooldown_time: 60 - # Optional fallback map: - # fallbacks: - # - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]} -# ============================================================================= -# Static GLOBAL Chat Models -# ============================================================================= + # Fallback models (optional) - when primary fails, try these + # Format: [{"primary_model": ["fallback1", "fallback2"]}] + # fallbacks: [] + global_llm_configs: - # Premium Azure chat model with image input support and explicit custom - # pricing. This is the current shape to use for hosted GPT 5.x deployments. + # Example: OpenAI GPT-4 Turbo with citations enabled - id: -1 - name: "Azure GPT 5.1" - billing_tier: "premium" - anonymous_enabled: false - seo_enabled: false - seo_slug: "azure-gpt-5-1" - quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.1" - supports_image_input: true - supports_tools: true - max_input_tokens: 400000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - # api_version is optional. Include it if your Azure deployment requires a - # specific API version. - # api_version: "2025-04-01-preview" - rpm: 47500 - tpm: 14750000 - litellm_params: - max_tokens: 16384 - base_model: "gpt-5.1" - input_cost_per_token: 0.00000125 - output_cost_per_token: 0.00001 - system_instructions: "" - use_default_system_instructions: true - citations_enabled: true - - # Larger premium chat model. If your provider prices long-context traffic - # differently, choose a conservative flat price or document the limitation - # next to the inline pricing. - - id: -2 - name: "Azure GPT 5.4" - billing_tier: "premium" - anonymous_enabled: false - seo_enabled: false - seo_slug: "azure-gpt-5-4" - quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.4" - supports_image_input: true - supports_tools: true - max_input_tokens: 400000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 150000 - tpm: 15000000 - litellm_params: - max_tokens: 16384 - base_model: "gpt-5.4" - input_cost_per_token: 0.0000025 - output_cost_per_token: 0.000015 - system_instructions: "" - use_default_system_instructions: true - citations_enabled: true - - # Free/no-login hosted model. Free models are visible to users when - # anonymous_enabled/seo_enabled are true but do not debit premium credits. - - id: -3 - name: "Azure GPT 5.4 Mini" + name: "Global GPT-4 Turbo" + description: "OpenAI's GPT-4 Turbo with default prompts and citations" billing_tier: "free" anonymous_enabled: true seo_enabled: true - seo_slug: "gpt-5-4-mini-no-login" - seo_title: "Free GPT 5.4 Mini Chat" - seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in." + seo_slug: "gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.4-mini" - supports_image_input: false - supports_tools: true - max_input_tokens: 128000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 15000 - tpm: 15000000 + provider: "OPENAI" + model_name: "gpt-4-turbo-preview" + api_key: "sk-your-openai-api-key-here" + api_base: "" + # Rate limits for load balancing (requests/tokens per minute) + rpm: 500 # Requests per minute + tpm: 100000 # Tokens per minute litellm_params: - max_tokens: 16384 - base_model: "gpt-5.4-mini" + temperature: 0.7 + max_tokens: 4000 + # Prompt Configuration + system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS + use_default_system_instructions: true + citations_enabled: true + + # Example: Anthropic Claude 3 Opus + - id: -2 + name: "Global Claude 3 Opus" + description: "Anthropic's most capable model with citations" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "claude-3-opus" + quota_reserve_tokens: 4000 + provider: "ANTHROPIC" + model_name: "claude-3-opus-20240229" + api_key: "sk-ant-your-anthropic-api-key-here" + api_base: "" + rpm: 1000 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 system_instructions: "" use_default_system_instructions: true citations_enabled: true - # Planner LLM. This is operator-only and is not shown in the user-facing - # model selector. Only one global_llm_configs entry should set is_planner. + # Example: Fast model - GPT-3.5 Turbo (citations disabled for speed) + - id: -3 + name: "Global GPT-3.5 Turbo (Fast)" + description: "Fast responses without citations for quick queries" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "gpt-3.5-turbo-fast" + quota_reserve_tokens: 2000 + provider: "OPENAI" + model_name: "gpt-3.5-turbo" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 3500 # GPT-3.5 has higher rate limits + tpm: 200000 + litellm_params: + temperature: 0.5 + max_tokens: 2000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: false # Disabled for faster responses + + # Example: Chinese LLM - DeepSeek with custom instructions + - id: -4 + name: "Global DeepSeek Chat (Chinese)" + description: "DeepSeek optimized for Chinese language responses" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "deepseek-chat-chinese" + quota_reserve_tokens: 4000 + provider: "DEEPSEEK" + model_name: "deepseek-chat" + api_key: "your-deepseek-api-key-here" + api_base: "https://api.deepseek.com/v1" + rpm: 60 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + # Custom system instructions for Chinese responses + system_instructions: | + <system_instruction> + You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. + + Today's date (UTC): {resolved_today} + + IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language. + </system_instruction> + use_default_system_instructions: false + citations_enabled: true + + # Example: Azure OpenAI GPT-4o + # IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params + # to enable accurate token counting, cost tracking, and max token limits + - id: -5 + name: "Global Azure GPT-4o" + description: "Azure OpenAI GPT-4o deployment" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "azure-gpt-4o" + quota_reserve_tokens: 4000 + provider: "AZURE" + # model_name format for Azure: azure/<your-deployment-name> + model_name: "azure/gpt-4o-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" # Azure API version + rpm: 1000 + tpm: 150000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + # REQUIRED for Azure: Specify the underlying OpenAI model + # This fixes "Could not identify azure model" warnings + # Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo + base_model: "gpt-4o" + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Azure OpenAI GPT-4 Turbo + - id: -6 + name: "Global Azure GPT-4 Turbo" + description: "Azure OpenAI GPT-4 Turbo deployment" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "azure-gpt-4-turbo" + quota_reserve_tokens: 4000 + provider: "AZURE" + model_name: "azure/gpt-4-turbo-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" + rpm: 500 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Groq - Fast inference + - id: -7 + name: "Global Groq Llama 3" + description: "Ultra-fast Llama 3 70B via Groq" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "groq-llama-3" + quota_reserve_tokens: 8000 + provider: "GROQ" + model_name: "llama3-70b-8192" + api_key: "your-groq-api-key-here" + api_base: "" + rpm: 30 # Groq has lower rate limits on free tier + tpm: 14400 + litellm_params: + temperature: 0.7 + max_tokens: 8000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: MiniMax M3 - High-performance with 512K context window + - id: -8 + name: "Global MiniMax M3" + description: "MiniMax M3 with 512K context window and competitive pricing" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "minimax-m3" + quota_reserve_tokens: 4000 + provider: "MINIMAX" + model_name: "MiniMax-M3" + api_key: "your-minimax-api-key-here" + api_base: "https://api.minimax.io/v1" + rpm: 60 + tpm: 100000 + litellm_params: + temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0 + max_tokens: 4000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Planner LLM - small, fast model used for internal utility tasks + # + # The PLANNER role handles short, structured internal calls (KB query + # rewriting, date extraction, recency classification, etc.) that don't + # need frontier-tier capability. Pointing the planner at a cheap+fast + # model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...) + # typically saves 500ms-1.5s per turn vs. routing those same internal + # calls through the user's chat model. + # + # Activation: + # - Mark EXACTLY ONE global config with ``is_planner: true``. + # - If multiple are marked, the first one wins and a WARNING is logged. + # - If none is marked, every internal call falls back to the user's + # chat LLM (same behavior as before this flag existed). + # + # This config is operator-only — it is NOT exposed in the user-facing + # model selector, never billed against premium quota, and the + # billing_tier / anonymous_enabled fields below are ignored. - id: -9 - name: "Azure GPT 5.x Nano Planner" + name: "Global Planner (GPT-4o mini)" + description: "Internal-only planner LLM for query rewriting and classification" is_planner: true billing_tier: "free" anonymous_enabled: false seo_enabled: false quota_reserve_tokens: 1000 - provider: "azure" - model_name: "gpt-5.4-nano" - supports_image_input: false - supports_tools: false - router_pool_eligible: false - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 20000 - tpm: 4000000 + provider: "OPENAI" + model_name: "gpt-4o-mini" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 3500 + tpm: 200000 litellm_params: temperature: 0 max_tokens: 1000 - base_model: "gpt-5.4-nano" system_instructions: "" use_default_system_instructions: true citations_enabled: false # ============================================================================= -# OpenRouter Dynamic Model Integration +# OpenRouter Integration # ============================================================================= -# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects -# supported models as GLOBAL chat and optionally image-generation models. -# Tier is derived per model from OpenRouter data: -# - model id ends with ":free" -> billing_tier=free -# - prompt and completion pricing are zero -> billing_tier=free -# - otherwise -> billing_tier=premium -# -# Do not use deprecated openrouter_integration.billing_tier or -# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous -# switches below. +# When enabled, dynamically fetches ALL available models from the OpenRouter API +# and injects them as global configs. This gives premium users access to any model +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. +# Models are fetched at startup and refreshed periodically in the background. +# All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" + # Tier is derived PER MODEL from OpenRouter's own API signals: + # - id ends with ":free" -> billing_tier=free + # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free + # - otherwise -> billing_tier=premium + # No global billing_tier knob is honored; any legacy value emits a startup warning. + + # Anonymous access is split by tier so operators can expose only free + # models to no-login users without leaking paid inference. anonymous_enabled_paid: false anonymous_enabled_free: false + seo_enabled: false + # quota_reserve_tokens: tokens reserved per call for quota enforcement quota_reserve_tokens: 4000 - - # Base negative ID namespace for dynamic chat models. IDs are derived - # deterministically so they survive catalog churn. Do not overlap static IDs. + # id_offset: base negative ID for dynamically generated configs. + # Model IDs are derived deterministically via BLAKE2b so they survive + # catalogue churn. Must not overlap with your static global_llm_configs IDs. id_offset: -10000 - - # Separate base negative ID namespace for dynamic image-generation models. - image_id_offset: -20000 - - # How often to refresh the OpenRouter catalog. 0 means startup only. + # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 - # Paid OpenRouter models may join curated router pools when eligible. + # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router + # for per-deployment accounting when OR premium models participate in the + # shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your + # real account limits live at https://openrouter.ai/settings/limits. rpm: 200 tpm: 1000000 - # Free OpenRouter models are available for user-facing selection/pinning but - # should be treated as a shared-account bucket, not normal router capacity. + # Rate limits for FREE OpenRouter models. Informational only: free OR + # models are intentionally kept OUT of the LiteLLM Router pool, because + # OpenRouter enforces free-tier limits globally per account (~20 RPM + + # 50-1000 daily requests across every ":free" model combined) — + # per-deployment router accounting can't represent a shared bucket + # correctly. Free OR models stay fully available in the model selector + # and for user-facing Auto thread pinning. free_rpm: 20 free_tpm: 100000 - # Image generation is opt-in to avoid injecting a large image catalog during - # upgrades. Vision-capable chat models are represented with - # supports_image_input: true. + # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false. image_generation_enabled: false vision_enabled: false @@ -241,80 +367,191 @@ openrouter_integration: citations_enabled: true # ============================================================================= -# Image Generation Auto Mode Router Settings +# Image Generation Configuration # ============================================================================= +# These configurations power the image generation feature using litellm.aimage_generation(). +# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, +# Recraft, OpenRouter, Xinference, Nscale +# +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs. + +# Router Settings for Image Generation Auto Mode image_generation_router_settings: routing_strategy: "usage-based-routing" num_retries: 3 allowed_fails: 3 cooldown_time: 60 -# ============================================================================= -# Static GLOBAL Image Generation Models -# ============================================================================= global_image_generation_configs: - - id: -2001 - name: "Azure GPT Image 1.5" - billing_tier: "premium" - provider: "azure" - model_name: "gpt-image-1.5" + # Example: OpenAI DALL-E 3 + - id: -1 + name: "Global DALL-E 3" + description: "OpenAI's DALL-E 3 for high-quality image generation" + provider: "OPENAI" + model_name: "dall-e-3" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) + litellm_params: {} + + # Example: OpenAI GPT Image 1 + - id: -2 + name: "Global GPT Image 1" + description: "OpenAI's GPT Image 1 model" + provider: "OPENAI" + model_name: "gpt-image-1" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 + litellm_params: {} + + # Example: Azure OpenAI DALL-E 3 + - id: -3 + name: "Global Azure DALL-E 3" + description: "Azure-hosted DALL-E 3 deployment" + provider: "AZURE_OPENAI" + model_name: "azure/dall-e-3-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" - # api_version: "2025-04-01-preview" - rpm: 60 + api_version: "2024-02-15-preview" + rpm: 50 litellm_params: - base_model: "gpt-image-1.5" + base_model: "dall-e-3" - - id: -2002 - name: "Azure GPT Image 1 Mini" - billing_tier: "free" - provider: "azure" - model_name: "gpt-image-1-mini" - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - # api_version: "2025-04-01-preview" - rpm: 120 - litellm_params: - base_model: "gpt-image-1-mini" + # Example: OpenRouter Gemini Image Generation + # - id: -4 + # name: "Global Gemini Image Gen" + # description: "Google Gemini image generation via OpenRouter" + # provider: "OPENROUTER" + # model_name: "google/gemini-2.5-flash-image" + # api_key: "your-openrouter-api-key-here" + # api_base: "" + # rpm: 30 + # litellm_params: {} # ============================================================================= -# Field Notes +# Vision LLM Configuration # ============================================================================= -# Common chat/image fields: -# - provider: Canonical provider adapter name. Example: azure, openai, -# anthropic, openrouter, groq, bedrock. -# - model_name: Provider model or deployment id. For Azure, use the bare -# deployment name. The resolver prefixes LiteLLM model strings from provider. -# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the -# resolver adds /v1 when needed. -# - api_version: Optional provider-specific API version, stored on the -# materialized connection extra metadata. -# - litellm_params: Passed to LiteLLM when invoking the model. Also used for -# base_model and inline pricing registration. +# These configurations power the vision autocomplete feature (screenshot analysis). +# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3). +# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock, +# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom # -# Chat model fields: -# - supports_image_input: true when the chat model can consume image inputs. -# - supports_tools: true when the model can use tools/function calling. -# - max_input_tokens: Optional UI/catalog metadata for context size. -# - router_pool_eligible: false keeps a model out of shared router pools while -# still allowing direct selection/pinning. -# - is_planner: true marks the internal-only planner model. Only one config -# should set this flag. +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs. + +# Router Settings for Vision LLM Auto Mode +vision_llm_router_settings: + routing_strategy: "usage-based-routing" + num_retries: 3 + allowed_fails: 3 + cooldown_time: 60 + +global_vision_llm_configs: + # Example: OpenAI GPT-4o (recommended for vision) + - id: -1 + name: "Global GPT-4o Vision" + description: "OpenAI's GPT-4o with strong vision capabilities" + provider: "OPENAI" + model_name: "gpt-4o" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 500 + tpm: 100000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # Example: Google Gemini 2.0 Flash + - id: -2 + name: "Global Gemini 2.0 Flash" + description: "Google's fast vision model with large context" + provider: "GOOGLE" + model_name: "gemini-2.0-flash" + api_key: "your-google-ai-api-key-here" + api_base: "" + rpm: 1000 + tpm: 200000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # Example: Anthropic Claude 3.5 Sonnet + - id: -3 + name: "Global Claude 3.5 Sonnet Vision" + description: "Anthropic's Claude 3.5 Sonnet with vision support" + provider: "ANTHROPIC" + model_name: "claude-3-5-sonnet-20241022" + api_key: "sk-ant-your-anthropic-api-key-here" + api_base: "" + rpm: 1000 + tpm: 100000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # Example: Azure OpenAI GPT-4o + # - id: -4 + # name: "Global Azure GPT-4o Vision" + # description: "Azure-hosted GPT-4o for vision analysis" + # provider: "AZURE_OPENAI" + # model_name: "azure/gpt-4o-deployment" + # api_key: "your-azure-api-key-here" + # api_base: "https://your-resource.openai.azure.com" + # api_version: "2024-02-15-preview" + # rpm: 500 + # tpm: 100000 + # litellm_params: + # temperature: 0.3 + # max_tokens: 1000 + # base_model: "gpt-4o" + +# Notes: +# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing +# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) +# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.) +# - The 'api_key' field will not be exposed to users via API +# - system_instructions: Custom prompt or empty string to use defaults +# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty +# - citations_enabled: true = include citation instructions, false = include anti-citation instructions +# - All standard LiteLLM providers are supported +# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) +# These help the router distribute load evenly and avoid rate limit errors # -# Catalog and access fields: -# - billing_tier: "free" or "premium". -# - anonymous_enabled: Whether the model appears in the public no-login catalog. -# - seo_enabled: Whether a /free/<seo_slug> landing page is generated. -# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once -# public. -# - seo_title / seo_description: Optional SEO metadata overrides. -# - quota_reserve_tokens: Tokens reserved before each chat LLM call. -# - rpm / tpm: Optional rate limits for router accounting and load balancing. # -# Image generation notes: -# - Image-generation configs use the same GLOBAL ID namespace as chat models. -# - Only RPM is relevant for most image-generation APIs. -# - The runtime uses litellm.aimage_generation(). -# - Image billing currently uses billing_tier and model catalog metadata. Keep -# quota reserve tuning in code/catalog unless the materializer copies a YAML -# key for image quota reservation. +# IMAGE GENERATION NOTES: +# - Image generation configs use the same ID scheme as LLM configs (negative for global) +# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure), +# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter) +# - The router uses litellm.aimage_generation() for async image generation +# - Only RPM (requests per minute) is relevant for image generation rate limiting. +# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token. +# +# VISION LLM NOTES: +# - Vision configs use the same ID scheme (negative for global, positive for user DB) +# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.) +# - Lower temperature (0.3) is recommended for accurate screenshot analysis +# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions +# +# PLANNER LLM NOTES: +# - is_planner: true marks a config as the internal-only planner LLM (small, +# fast model used for KB query rewriting, date extraction, recency +# classification, etc.). Only one config may carry this flag — if +# multiple do, the first one wins and a startup WARNING is logged. +# - When no config is marked is_planner, every internal utility call falls +# back to the user's chat LLM (the historical behavior). +# - Planner configs are NOT shown in the user-facing model selector and +# are NOT billed against the user's premium quota. Their billing_tier, +# anonymous_enabled, seo_* fields are ignored. +# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash, +# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k +# prompt. Frontier models here defeat the purpose of the flag. +# +# TOKEN QUOTA & ANONYMOUS ACCESS NOTES: +# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota. +# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog. +# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated. +# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public. +# - seo_title: Optional HTML title tag override for the model's /free/<slug> page. +# - seo_description: Optional meta description override for the model's /free/<slug> page. +# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement. +# Independent of litellm_params.max_tokens. Used by the token quota service. diff --git a/surfsense_backend/app/connectors/dropbox/content_extractor.py b/surfsense_backend/app/connectors/dropbox/content_extractor.py index 300010c26..372d2fc82 100644 --- a/surfsense_backend/app/connectors/dropbox/content_extractor.py +++ b/surfsense_backend/app/connectors/dropbox/content_extractor.py @@ -90,12 +90,11 @@ async def download_and_extract_content( if error: return None, metadata, error - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=temp_file_path, filename=file_name), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=temp_file_path, filename=file_name) ) markdown = result.markdown_content return markdown, metadata, None diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 1ea047978..59392831d 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -122,13 +122,12 @@ async def download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown via the cache-aware ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache + """Parse a local file to markdown using the unified ETL pipeline.""" from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename) ) return result.markdown_content diff --git a/surfsense_backend/app/connectors/onedrive/content_extractor.py b/surfsense_backend/app/connectors/onedrive/content_extractor.py index fb1d31fbc..3154f2eca 100644 --- a/surfsense_backend/app/connectors/onedrive/content_extractor.py +++ b/surfsense_backend/app/connectors/onedrive/content_extractor.py @@ -84,12 +84,11 @@ async def download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown via the cache-aware ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache + """Parse a local file to markdown using the unified ETL pipeline.""" from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename) ) return result.markdown_content diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 3f098d5d2..6117caecb 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1,4 +1,3 @@ -import logging import uuid from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -35,8 +34,6 @@ from app.config import config if config.AUTH_TYPE == "GOOGLE": from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID -logger = logging.getLogger(__name__) - DATABASE_URL = config.DATABASE_URL @@ -117,6 +114,13 @@ class SearchSourceConnectorType(StrEnum): COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" +class PodcastStatus(StrEnum): + PENDING = "pending" + GENERATING = "generating" + READY = "ready" + FAILED = "failed" + + class VideoPresentationStatus(StrEnum): PENDING = "pending" GENERATING = "generating" @@ -201,15 +205,79 @@ class DocumentStatus: return None -class ConnectionScope(StrEnum): - GLOBAL = "GLOBAL" - SEARCH_SPACE = "SEARCH_SPACE" - USER = "USER" +class LiteLLMProvider(StrEnum): + """ + Enum for LLM providers supported by LiteLLM. + """ + + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + GOOGLE = "GOOGLE" + AZURE_OPENAI = "AZURE_OPENAI" + BEDROCK = "BEDROCK" + VERTEX_AI = "VERTEX_AI" + GROQ = "GROQ" + COHERE = "COHERE" + MISTRAL = "MISTRAL" + DEEPSEEK = "DEEPSEEK" + XAI = "XAI" + OPENROUTER = "OPENROUTER" + TOGETHER_AI = "TOGETHER_AI" + FIREWORKS_AI = "FIREWORKS_AI" + REPLICATE = "REPLICATE" + PERPLEXITY = "PERPLEXITY" + OLLAMA = "OLLAMA" + ALIBABA_QWEN = "ALIBABA_QWEN" + MOONSHOT = "MOONSHOT" + ZHIPU = "ZHIPU" + ANYSCALE = "ANYSCALE" + DEEPINFRA = "DEEPINFRA" + CEREBRAS = "CEREBRAS" + SAMBANOVA = "SAMBANOVA" + AI21 = "AI21" + CLOUDFLARE = "CLOUDFLARE" + DATABRICKS = "DATABRICKS" + COMETAPI = "COMETAPI" + HUGGINGFACE = "HUGGINGFACE" + GITHUB_MODELS = "GITHUB_MODELS" + MINIMAX = "MINIMAX" + CUSTOM = "CUSTOM" -class ModelSource(StrEnum): - DISCOVERED = "DISCOVERED" - MANUAL = "MANUAL" +class ImageGenProvider(StrEnum): + """ + Enum for image generation providers supported by LiteLLM. + This is a subset of LLM providers — only those that support image generation. + See: https://docs.litellm.ai/docs/image_generation#supported-providers + """ + + OPENAI = "OPENAI" + AZURE_OPENAI = "AZURE_OPENAI" + GOOGLE = "GOOGLE" # Google AI Studio + VERTEX_AI = "VERTEX_AI" + BEDROCK = "BEDROCK" # AWS Bedrock + RECRAFT = "RECRAFT" + OPENROUTER = "OPENROUTER" + XINFERENCE = "XINFERENCE" + NSCALE = "NSCALE" + + +class VisionProvider(StrEnum): + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + GOOGLE = "GOOGLE" + AZURE_OPENAI = "AZURE_OPENAI" + VERTEX_AI = "VERTEX_AI" + BEDROCK = "BEDROCK" + XAI = "XAI" + OPENROUTER = "OPENROUTER" + OLLAMA = "OLLAMA" + GROQ = "GROQ" + TOGETHER_AI = "TOGETHER_AI" + FIREWORKS_AI = "FIREWORKS_AI" + DEEPSEEK = "DEEPSEEK" + MISTRAL = "MISTRAL" + CUSTOM = "CUSTOM" class LogLevel(StrEnum): @@ -252,7 +320,7 @@ class PagePurchaseStatus(StrEnum): FAILED = "failed" -class CreditPurchaseStatus(StrEnum): +class PremiumTokenPurchaseStatus(StrEnum): PENDING = "pending" COMPLETED = "completed" FAILED = "failed" @@ -264,27 +332,26 @@ INCENTIVE_TASKS_CONFIG = { IncentiveTaskType.GITHUB_STAR: { "title": "Star our GitHub repository", "description": "Show your support by starring SurfSense on GitHub", - # Credit reward in USD micro-units (1_000_000 == $1.00). $0.03. - "credit_micros_reward": 30000, + "pages_reward": 30, "action_url": "https://github.com/MODSetter/SurfSense", }, IncentiveTaskType.REDDIT_FOLLOW: { "title": "Join our Subreddit", "description": "Join the SurfSense community on Reddit", - "credit_micros_reward": 30000, + "pages_reward": 30, "action_url": "https://www.reddit.com/r/SurfSense/", }, IncentiveTaskType.DISCORD_JOIN: { "title": "Join our Discord", "description": "Join the SurfSense community on Discord", - "credit_micros_reward": 40000, + "pages_reward": 40, "action_url": "https://discord.gg/ejRNvftDp9", }, # Future tasks can be configured here: # IncentiveTaskType.GITHUB_ISSUE: { # "title": "Create an issue", # "description": "Help improve SurfSense by reporting bugs or suggesting features", - # "credit_micros_reward": 50000, + # "pages_reward": 50, # "action_url": "https://github.com/MODSetter/SurfSense/issues/new/choose", # }, } @@ -638,11 +705,11 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto model pin for this thread: concrete resolved global LLM + # Auto (Fastest) model pin for this thread: concrete resolved global LLM # config id. NULL means no pin; Auto will resolve on the next turn. # Single-writer invariant: only app.services.auto_model_pin_service sets # or clears this column (plus bulk clears when a search space's - # chat_model_id changes). Unindexed: all reads are by primary key. + # agent_llm_id changes). Unindexed: all reads are by primary key. pinned_llm_config_id = Column(Integer, nullable=True) # Surface metadata for first-party SurfSense and external chat threads. @@ -1423,10 +1490,7 @@ class Document(BaseModel, TimestampMixin): created_by = relationship("User", back_populates="documents") connector = relationship("SearchSourceConnector", back_populates="documents") chunks = relationship( - "Chunk", - back_populates="document", - cascade="all, delete-orphan", - order_by="Chunk.position", + "Chunk", back_populates="document", cascade="all, delete-orphan" ) # Original upload + future derived artifacts (redacted, filled-form). # Model lives in app.file_storage.persistence to keep that feature cohesive. @@ -1462,11 +1526,6 @@ class Chunk(BaseModel, TimestampMixin): content = Column(Text, nullable=False) embedding = Column(Vector(config.embedding_model_instance.dimension)) - # Explicit document order; ids don't follow it since incremental - # re-indexing keeps unchanged rows across edits. Deliberately not indexed: - # ordering reads are document-scoped (covered by ix_chunks_document_id) and - # building a position index on the large chunks table is not worth it. - position = Column(Integer, nullable=False, server_default="0") document_id = Column( Integer, @@ -1477,6 +1536,41 @@ class Chunk(BaseModel, TimestampMixin): document = relationship("Document", back_populates="chunks") +class Podcast(BaseModel, TimestampMixin): + """Podcast model for storing generated podcasts.""" + + __tablename__ = "podcasts" + + title = Column(String(500), nullable=False) + podcast_transcript = Column(JSONB, nullable=True) + file_location = Column(Text, nullable=True) + status = Column( + SQLAlchemyEnum( + PodcastStatus, + name="podcast_status", + create_type=False, + values_callable=lambda x: [e.value for e in x], + ), + nullable=False, + default=PodcastStatus.READY, + server_default="ready", + index=True, + ) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="podcasts") + + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + thread = relationship("NewChatThread") + + class VideoPresentation(BaseModel, TimestampMixin): """Video presentation model for storing AI-generated video presentations. @@ -1548,80 +1642,73 @@ class Report(BaseModel, TimestampMixin): thread = relationship("NewChatThread") -class Connection(BaseModel, TimestampMixin): - __tablename__ = "connections" +class ImageGenerationConfig(BaseModel, TimestampMixin): + """ + Dedicated configuration table for image generation models. - provider = Column(String(100), nullable=False, index=True) - base_url = Column(String(500), nullable=True) - api_key = Column(String, nullable=True) - extra = Column(JSONB, nullable=False, default=dict, server_default="{}") - scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True) - enabled = Column(Boolean, nullable=False, default=True, server_default="true") + Separate from NewLLMConfig because image generation models don't need + system_instructions, citations_enabled, or use_default_system_instructions. + They only need provider credentials and model parameters. + """ + + __tablename__ = "image_generation_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) + provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + # Credentials + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) # Azure-specific + + # Additional litellm parameters + litellm_params = Column(JSON, nullable=True, default={}) + + # Relationships + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship( + "SearchSpace", back_populates="image_generation_configs" + ) + + # User who created this config + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + user = relationship("User", back_populates="image_generation_configs") + + +class VisionLLMConfig(BaseModel, TimestampMixin): + __tablename__ = "vision_llm_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) + + litellm_params = Column(JSON, nullable=True, default={}) search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False ) + search_space = relationship("SearchSpace", back_populates="vision_llm_configs") + user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - - search_space = relationship("SearchSpace", back_populates="connections") - user = relationship("User", back_populates="connections") - models = relationship( - "Model", - back_populates="connection", - order_by="Model.id", - cascade="all, delete-orphan", - passive_deletes=True, - ) - - __table_args__ = ( - CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - ) - - -class Model(BaseModel, TimestampMixin): - __tablename__ = "models" - - connection_id = Column( - Integer, - ForeignKey("connections.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - model_id = Column(String(255), nullable=False) - display_name = Column(String(255), nullable=True) - source = Column( - SQLAlchemyEnum(ModelSource), - nullable=False, - default=ModelSource.DISCOVERED, - server_default=ModelSource.DISCOVERED.value, - ) - supports_chat = Column(Boolean, nullable=True) - max_input_tokens = Column(Integer, nullable=True) - supports_image_input = Column(Boolean, nullable=True) - supports_tools = Column(Boolean, nullable=True) - supports_image_generation = Column(Boolean, nullable=True) - capabilities_override = Column( - JSONB, nullable=False, default=dict, server_default="{}" - ) - enabled = Column(Boolean, nullable=False, default=True, server_default="true") - billing_tier = Column(String(50), nullable=True, index=True) - catalog = Column(JSONB, nullable=False, default=dict, server_default="{}") - - connection = relationship("Connection", back_populates="models") - - __table_args__ = ( - UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - Index("ix_models_model_id", "model_id"), + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) + user = relationship("User", back_populates="vision_llm_configs") class ImageGeneration(BaseModel, TimestampMixin): @@ -1655,9 +1742,10 @@ class ImageGeneration(BaseModel, TimestampMixin): style = Column(String(50), nullable=True) # Model-specific style parameter response_format = Column(String(50), nullable=True) # "url" or "b64_json" - # Image generation model provenance. - # 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records. - image_gen_model_id = Column(Integer, nullable=True) + # Image generation config reference + # 0 = Auto mode (router), negative IDs = global configs from YAML, + # positive IDs = ImageGenerationConfig records in DB + image_generation_config_id = Column(Integer, nullable=True) # Response data (full litellm response as JSONB) — present on success response_data = Column(JSONB, nullable=True) @@ -1699,19 +1787,19 @@ class SearchSpace(BaseModel, TimestampMixin): shared_memory_md = Column(Text, nullable=True, server_default="") - # Connection/model role bindings. - # Note: ID values preserve the existing convention: - # - 0: Auto mode - # - Negative IDs: Global virtual models from global_llm_config.yaml - # - Positive IDs: User/search-space models from the models table - chat_model_id = Column( - Integer, nullable=True, default=0, server_default="0" + # Search space-level LLM preferences (shared by all members) + # Note: ID values: + # - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces + # - Negative IDs: Global configs from YAML + # - Positive IDs: Custom configs from DB (NewLLMConfig table) + agent_llm_id = Column( + Integer, nullable=True, default=0 ) # For agent/chat operations, defaults to Auto mode - image_gen_model_id = Column( - Integer, nullable=True, default=0, server_default="0" - ) # For image generation, defaults to Auto mode when eligible - vision_model_id = Column( - Integer, nullable=True, default=0, server_default="0" + image_generation_config_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode + vision_llm_config_id = Column( + Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode ai_file_sort_enabled = Column( @@ -1783,12 +1871,23 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) - connections = relationship( - "Connection", + new_llm_configs = relationship( + "NewLLMConfig", back_populates="search_space", - order_by="Connection.id", + order_by="NewLLMConfig.id", + cascade="all, delete-orphan", + ) + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="search_space", + order_by="ImageGenerationConfig.id", + cascade="all, delete-orphan", + ) + vision_llm_configs = relationship( + "VisionLLMConfig", + back_populates="search_space", + order_by="VisionLLMConfig.id", cascade="all, delete-orphan", - passive_deletes=True, ) automations = relationship( @@ -1891,6 +1990,64 @@ class SearchSourceConnector(BaseModel, TimestampMixin): documents = relationship("Document", back_populates="connector") +class NewLLMConfig(BaseModel, TimestampMixin): + """ + New LLM configuration table that combines model settings with prompt configuration. + + This table provides: + - LLM model configuration (provider, model_name, api_key, etc.) + - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) + - Citation toggle (enable/disable citation instructions) + + Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory). + """ + + __tablename__ = "new_llm_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # === LLM Model Configuration (from original LLMConfig, excluding 'language') === + # Provider from the enum + provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) + # Custom provider name when provider is CUSTOM + custom_provider = Column(String(100), nullable=True) + # Just the model name without provider prefix + model_name = Column(String(100), nullable=False) + # API Key should be encrypted before storing + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + # For any other parameters that litellm supports + litellm_params = Column(JSON, nullable=True, default={}) + + # === Prompt Configuration === + # Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) + # Users can customize this from the UI + system_instructions = Column( + Text, + nullable=False, + default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS + ) + # Whether to use the default system instructions when system_instructions is empty + use_default_system_instructions = Column(Boolean, nullable=False, default=True) + + # Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected + # When disabled, an anti-citation prompt is injected instead + citations_enabled = Column(Boolean, nullable=False, default=True) + + # === Relationships === + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="new_llm_configs") + + # User who created this config + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + user = relationship("User", back_populates="new_llm_configs") + + class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -1912,7 +2069,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin): """ Tracks completed incentive tasks for users. Each user can only complete each task type once. - When a task is completed, the user's credit_micros_balance is increased. + When a task is completed, the user's pages_limit is increased. """ __tablename__ = "user_incentive_tasks" @@ -1931,8 +2088,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin): index=True, ) task_type = Column(SQLAlchemyEnum(IncentiveTaskType), nullable=False, index=True) - # Credit reward granted in USD micro-units (1_000_000 == $1.00). - credit_micros_awarded = Column(BigInteger, nullable=False) + pages_awarded = Column(Integer, nullable=False) completed_at = Column( TIMESTAMP(timezone=True), nullable=False, @@ -1975,18 +2131,18 @@ class PagePurchase(Base, TimestampMixin): user = relationship("User", back_populates="page_purchases") -class CreditPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant credit (USD micro-units). +class PremiumTokenPurchase(Base, TimestampMixin): + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). - Renamed from ``premium_token_purchases`` in migration 156 as part of the - unified-credits wallet. ``credit_micros_granted`` stores the USD-micro - amount added to ``user.credit_micros_balance`` on fulfillment. - - ``source`` distinguishes a user-initiated checkout from an automatic - off-session top-up (auto-reload), added in the auto-reload migration. + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. """ - __tablename__ = "credit_purchases" + __tablename__ = "premium_token_purchases" __allow_unmapped__ = True id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -2004,18 +2160,15 @@ class CreditPurchase(Base, TimestampMixin): credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) - source = Column( - String(20), nullable=False, default="checkout", server_default="checkout" - ) status = Column( - SQLAlchemyEnum(CreditPurchaseStatus), + SQLAlchemyEnum(PremiumTokenPurchaseStatus), nullable=False, - default=CreditPurchaseStatus.PENDING, + default=PremiumTokenPurchaseStatus.PENDING, index=True, ) completed_at = Column(TIMESTAMP(timezone=True), nullable=True) - user = relationship("User", back_populates="credit_purchases") + user = relationship("User", back_populates="premium_token_purchases") class SearchSpaceRole(BaseModel, TimestampMixin): @@ -2257,8 +2410,22 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) - connections = relationship( - "Connection", + # LLM configs created by this user + new_llm_configs = relationship( + "NewLLMConfig", + back_populates="user", + passive_deletes=True, + ) + + # Image generation configs created by this user + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="user", + passive_deletes=True, + ) + + vision_llm_configs = relationship( + "VisionLLMConfig", back_populates="user", passive_deletes=True, ) @@ -2281,40 +2448,33 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", cascade="all, delete-orphan", ) - credit_purchases = relationship( - "CreditPurchase", + premium_token_purchases = relationship( + "PremiumTokenPurchase", back_populates="user", cascade="all, delete-orphan", ) - # Unified credit wallet (USD micro-units, 1_000_000 == $1.00). - # Decreases on use (ETL pages + premium model calls), increases on - # purchase / incentive grant / auto-reload. May dip slightly negative - # when an actual cost exceeds its pre-charge estimate; UI clamps at $0. - credit_micros_balance = Column( + # Page usage tracking for ETL services + pages_limit = Column( + Integer, + nullable=False, + default=config.PAGES_LIMIT, + server_default=str(config.PAGES_LIMIT), + ) + pages_used = Column(Integer, nullable=False, default=0, server_default="0") + + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.DEFAULT_CREDIT_MICROS_BALANCE, - server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - # In-flight reservation holds (released/settled at finalize). - credit_micros_reserved = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - - # Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED. - # ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the - # saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at`` - # is set (and ``auto_reload_enabled`` flipped off) when an off-session - # charge is declined so the UI can prompt the user to fix their card. - stripe_customer_id = Column(String, nullable=True) - auto_reload_enabled = Column( - Boolean, nullable=False, default=False, server_default="false" + premium_credit_micros_reserved = Column( + BigInteger, nullable=False, default=0, server_default="0" ) - auto_reload_threshold_micros = Column(BigInteger, nullable=True) - auto_reload_amount_micros = Column(BigInteger, nullable=True) - auto_reload_payment_method_id = Column(String, nullable=True) - auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True) # User profile from OAuth display_name = Column(String, nullable=True) @@ -2389,8 +2549,22 @@ else: passive_deletes=True, ) - connections = relationship( - "Connection", + # LLM configs created by this user + new_llm_configs = relationship( + "NewLLMConfig", + back_populates="user", + passive_deletes=True, + ) + + # Image generation configs created by this user + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="user", + passive_deletes=True, + ) + + vision_llm_configs = relationship( + "VisionLLMConfig", back_populates="user", passive_deletes=True, ) @@ -2413,40 +2587,33 @@ else: back_populates="user", cascade="all, delete-orphan", ) - credit_purchases = relationship( - "CreditPurchase", + premium_token_purchases = relationship( + "PremiumTokenPurchase", back_populates="user", cascade="all, delete-orphan", ) - # Unified credit wallet (USD micro-units, 1_000_000 == $1.00). - # Decreases on use (ETL pages + premium model calls), increases on - # purchase / incentive grant / auto-reload. May dip slightly negative - # when an actual cost exceeds its pre-charge estimate; UI clamps at $0. - credit_micros_balance = Column( + # Page usage tracking for ETL services + pages_limit = Column( + Integer, + nullable=False, + default=config.PAGES_LIMIT, + server_default=str(config.PAGES_LIMIT), + ) + pages_used = Column(Integer, nullable=False, default=0, server_default="0") + + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.DEFAULT_CREDIT_MICROS_BALANCE, - server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - # In-flight reservation holds (released/settled at finalize). - credit_micros_reserved = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - - # Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED. - # ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the - # saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at`` - # is set (and ``auto_reload_enabled`` flipped off) when an off-session - # charge is declined so the UI can prompt the user to fix their card. - stripe_customer_id = Column(String, nullable=True) - auto_reload_enabled = Column( - Boolean, nullable=False, default=False, server_default="false" + premium_credit_micros_reserved = Column( + BigInteger, nullable=False, default=0, server_default="0" ) - auto_reload_threshold_micros = Column(BigInteger, nullable=True) - auto_reload_amount_micros = Column(BigInteger, nullable=True) - auto_reload_payment_method_id = Column(String, nullable=True) - auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True) # User profile (can be set manually for non-OAuth users) display_name = Column(String, nullable=True) @@ -2720,38 +2887,8 @@ from app.automations.persistence import ( # noqa: E402, F401 AutomationRun, AutomationTrigger, ) -from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401 from app.file_storage.persistence import DocumentFile # noqa: E402, F401 -from app.indexing_pipeline.cache.persistence.models import ( # noqa: E402, F401 - CachedEmbeddingSet, -) from app.notifications.persistence import Notification # noqa: E402, F401 -from app.podcasts.persistence import ( # noqa: E402, F401 - Podcast, - PodcastStatus, -) - - -def _build_connect_args() -> dict: - """Build driver connect_args, including a protective idle-in-transaction - timeout for asyncpg connections. - - A single abandoned ``idle in transaction`` session can hold table/row locks - indefinitely and wedge writes plus boot-time DDL (the classic "FastAPI - stuck at application startup" failure). Setting - ``idle_in_transaction_session_timeout`` server-side makes Postgres reap such - sessions automatically. It never affects sessions that are actively running - statements — only ones that opened a transaction and went idle. - """ - connect_args: dict = {} - idle_ms = config.DB_IDLE_IN_TX_TIMEOUT_MS - # ``server_settings`` is asyncpg-specific; only apply it for that driver. - if idle_ms and idle_ms > 0 and DATABASE_URL and "asyncpg" in DATABASE_URL: - connect_args["server_settings"] = { - "idle_in_transaction_session_timeout": str(idle_ms) - } - return connect_args - engine = create_async_engine( DATABASE_URL, @@ -2760,7 +2897,6 @@ engine = create_async_engine( pool_recycle=1800, pool_pre_ping=True, pool_timeout=30, - connect_args=_build_connect_args(), ) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) @@ -2785,117 +2921,54 @@ async def shielded_async_session(): await session.close() -# (index_name, table, CREATE statement). Built with CONCURRENTLY so an index -# build only takes a non-blocking ShareUpdateExclusiveLock — ingestion -# INSERT/UPDATE on documents/chunks keep flowing while the index builds, and a -# slow build can never freeze the FastAPI lifespan or block writers. -_INDEX_DEFINITIONS: list[tuple[str, str, str]] = [ - ( - "document_vector_index", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)", - ), - ( - "document_search_index", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))", - ), - ( - "chucks_vector_index", - "chunks", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)", - ), - ( - "chucks_search_index", - "chunks", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))", - ), - # pg_trgm index for efficient ILIKE '%term%' searches on titles — critical - # for the document mention picker (@mentions) to scale. - ( - "idx_documents_title_trgm", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)", - ), - ( - "idx_documents_search_space_id", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)", - ), - # Covering index for "recent documents" query — enables index-only scan. - ( - "idx_documents_search_space_updated", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)", - ), -] - - -async def _drop_invalid_index(conn, name: str) -> None: - """Drop a leftover *invalid* index so it can be rebuilt. - - A ``CREATE INDEX CONCURRENTLY`` that is interrupted (timeout, crash, - cancellation) leaves behind an ``indisvalid = false`` index. Because the - name now exists, a later ``CREATE INDEX CONCURRENTLY IF NOT EXISTS`` would - skip it and the broken index would persist forever. Detect and drop it - first. - """ - result = await conn.execute( - text("SELECT indisvalid FROM pg_index WHERE indexrelid = to_regclass(:n)"), - {"n": name}, - ) - row = result.first() - if row is not None and row[0] is False: - logger.warning( - "[startup] dropping invalid leftover index %s before rebuild", name +async def setup_indexes(): + async with engine.begin() as conn: + # Create indexes + # Document embedding indexes + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))" + ) + ) + # Document Chuck Indexes + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))" + ) + ) + # pg_trgm indexes for efficient ILIKE '%term%' searches on titles + # Critical for document mention picker (@mentions) to scale + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)" + ) + ) + # B-tree index on search_space_id for fast filtering + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)" + ) + ) + # Covering index for "recent documents" query - enables index-only scan + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)" + ) ) - await conn.execute(text(f'DROP INDEX CONCURRENTLY IF EXISTS "{name}"')) - - -async def setup_indexes() -> None: - """Ensure search/vector indexes exist without ever blocking startup. - - Each index is created with ``CONCURRENTLY`` (so it never takes a blocking - SHARE lock on documents/chunks) under a short per-session ``lock_timeout`` - (so a contended boot fails fast instead of hanging the lifespan forever). - Failures are logged and swallowed per-index — a missing index just gets - retried on the next boot rather than crash-looping the API. - """ - lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS) - # AUTOCOMMIT is mandatory: CREATE INDEX CONCURRENTLY cannot run inside a - # transaction block. - async with engine.connect() as base_conn: - conn = await base_conn.execution_options(isolation_level="AUTOCOMMIT") - await conn.execute(text(f"SET lock_timeout = {lock_timeout_ms}")) - for name, table, ddl in _INDEX_DEFINITIONS: - try: - await _drop_invalid_index(conn, name) - await conn.execute(text(ddl)) - except Exception as exc: - # Non-fatal by design: a missing index is retried next boot. - logger.warning( - "[startup] index %s on %s not ready (%s: %s); " - "will retry on next boot", - name, - table, - exc.__class__.__name__, - exc, - ) async def create_db_and_tables(): - if not config.DB_BOOTSTRAP_ON_STARTUP: - logger.info( - "[startup] DB bootstrap skipped (DB_BOOTSTRAP_ON_STARTUP=FALSE); " - "schema/indexes are expected to be managed by migrations" - ) - return - - lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS) async with engine.begin() as conn: - # Fail fast instead of hanging forever if another session holds a - # conflicting lock on a table we need to touch. - await conn.execute(text(f"SET LOCAL lock_timeout = {lock_timeout_ms}")) await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) await conn.run_sync(Base.metadata.create_all) diff --git a/surfsense_backend/app/etl_pipeline/cache/__init__.py b/surfsense_backend/app/etl_pipeline/cache/__init__.py deleted file mode 100644 index 3f4585778..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Content-addressed reuse of expensive ETL parser output across workspaces.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.cached_extraction import extract_with_cache -from app.etl_pipeline.cache.service import EtlCacheService - -__all__ = [ - "EtlCacheService", - "extract_with_cache", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py deleted file mode 100644 index de4186b69..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Entry point: serve ETL parses from cache, parsing only on a miss.""" - -from __future__ import annotations - -import asyncio -import hashlib -import logging - -from app.config import config -from app.etl_pipeline.cache.eligibility import is_parse_cacheable -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.cache.settings import load_etl_cache_settings -from app.etl_pipeline.etl_document import EtlRequest, EtlResult -from app.etl_pipeline.etl_pipeline_service import EtlPipelineService -from app.observability import metrics - -logger = logging.getLogger(__name__) - -_HASH_CHUNK = 1024 * 1024 - - -async def extract_with_cache(request: EtlRequest, *, vision_llm=None) -> EtlResult: - """Drop-in for ``EtlPipelineService.extract`` that reuses prior parser output.""" - settings = load_etl_cache_settings() - - cacheable = is_parse_cacheable( - filename=request.filename, - etl_service=config.ETL_SERVICE, - cache_enabled=settings.enabled, - has_vision_llm=vision_llm is not None, - ) - if not cacheable: - return await EtlPipelineService(vision_llm=vision_llm).extract(request) - - key = ParseKey.for_document( - await asyncio.to_thread(_hash_file, request.file_path), - etl_service=config.ETL_SERVICE, - mode=request.processing_mode.value, - version=settings.parser_version, - ) - - cached_result = await _recall(key) - if cached_result is not None: - metrics.record_etl_cache_lookup( - etl_service=key.etl_service, mode=key.mode, outcome="hit" - ) - logger.debug("ETL cache hit for %s", key.source_sha256) - return cached_result - - metrics.record_etl_cache_lookup( - etl_service=key.etl_service, mode=key.mode, outcome="miss" - ) - result = await EtlPipelineService(vision_llm=vision_llm).extract(request) - await _remember(key, result) - return result - - -async def _recall(key: ParseKey) -> EtlResult | None: - # Caching is best-effort: any failure falls through to a normal parse. - try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - return await EtlCacheService(session).recall(key) - except Exception: - logger.warning("ETL cache recall failed; parsing fresh", exc_info=True) - return None - - -async def _remember(key: ParseKey, result: EtlResult) -> None: - try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - await EtlCacheService(session).remember(key, result) - except Exception: - logger.warning("ETL cache write failed; result not cached", exc_info=True) - - -def _hash_file(path: str) -> str: - digest = hashlib.sha256() - with open(path, "rb") as handle: - for chunk in iter(lambda: handle.read(_HASH_CHUNK), b""): - digest.update(chunk) - return digest.hexdigest() diff --git a/surfsense_backend/app/etl_pipeline/cache/eligibility.py b/surfsense_backend/app/etl_pipeline/cache/eligibility.py deleted file mode 100644 index 18f096218..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eligibility.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Gating rule: may this upload be served from / written to the parse cache?""" - -from __future__ import annotations - -from app.etl_pipeline.file_classifier import FileCategory, classify_file - - -def is_parse_cacheable( - *, - filename: str, - etl_service: str | None, - cache_enabled: bool, - has_vision_llm: bool, -) -> bool: - """Only deterministic document parses are shareable across workspaces. - - Vision-LLM runs append model-generated content not captured by the cache key, - and a missing ETL service means there is no document parser to key against -- - both bypass the cache. Non-document categories (plaintext, audio, images, - direct-convert) are cheap or parser-agnostic and are handled outside it. - """ - if not cache_enabled: - return False - if has_vision_llm: - return False - if not etl_service: - return False - return classify_file(filename) == FileCategory.DOCUMENT diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py deleted file mode 100644 index f47b9c4e0..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Background pruning of the parse cache by age and size budget.""" - -from __future__ import annotations - -from .task import evict_etl_cache_task - -__all__ = [ - "evict_etl_cache_task", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py b/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py deleted file mode 100644 index 5a80752d6..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Pure selection rules for which cached entries to drop.""" - -from __future__ import annotations - -from collections.abc import Iterable - -from app.etl_pipeline.cache.schemas import EvictionCandidate - - -def select_over_budget( - coldest_first: Iterable[EvictionCandidate], - *, - current_total_bytes: int, - max_total_bytes: int, -) -> list[EvictionCandidate]: - """Pick coldest entries until the footprint drops under the budget.""" - bytes_to_free = current_total_bytes - max_total_bytes - if bytes_to_free <= 0: - return [] - - chosen: list[EvictionCandidate] = [] - bytes_freed = 0 - for candidate in coldest_first: - if bytes_freed >= bytes_to_free: - break - chosen.append(candidate) - bytes_freed += candidate.size_bytes - return chosen diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py deleted file mode 100644 index 61433f8a7..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Celery task that prunes the parse cache by TTL, then by size budget.""" - -from __future__ import annotations - -import contextlib -import logging -from datetime import UTC, datetime, timedelta - -from app.celery_app import celery_app -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.etl_pipeline.cache.settings import load_etl_cache_settings -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.observability import metrics -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="evict_etl_cache") -def evict_etl_cache_task(): - return run_async_celery_task(_evict) - - -async def _evict() -> None: - """Expire stale entries, then shed the coldest overflow only if still over budget.""" - settings = load_etl_cache_settings() - if not settings.enabled: - return - - store = MarkdownCacheStore() - async with get_celery_session_maker()() as session: - index = CachedParseRepository(session) - - cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) - expired = await index.select_expired( - cutoff=cutoff, limit=settings.eviction_batch - ) - await _drop(index, store, expired, phase="ttl") - - total = await index.total_size_bytes() - if total > settings.max_total_bytes: - coldest = await index.select_coldest(limit=settings.eviction_batch) - over_budget = select_over_budget( - coldest, - current_total_bytes=total, - max_total_bytes=settings.max_total_bytes, - ) - await _drop(index, store, over_budget, phase="size") - - -async def _drop( - index: CachedParseRepository, - store: MarkdownCacheStore, - candidates: list[EvictionCandidate], - *, - phase: str, -) -> None: - if not candidates: - return - for candidate in candidates: - # Drop the index row even if the blob delete fails (orphan blob is harmless). - with contextlib.suppress(Exception): - await store.delete(candidate.storage_key) - await index.delete_by_ids([candidate.id for candidate in candidates]) - metrics.record_etl_cache_eviction(len(candidates), phase=phase) - logger.info("Evicted %d cached parses (%s)", len(candidates), phase) diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py deleted file mode 100644 index 666e4cfa8..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Database access for cached parse rows.""" - -from __future__ import annotations - -from .models import CachedParse -from .repository import CachedParseRepository - -__all__ = [ - "CachedParse", - "CachedParseRepository", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/models.py b/surfsense_backend/app/etl_pipeline/cache/persistence/models.py deleted file mode 100644 index bd20bdd12..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/models.py +++ /dev/null @@ -1,49 +0,0 @@ -"""``etl_cache_parses``: one reusable parser result per (bytes + recipe).""" - -from __future__ import annotations - -from sqlalchemy import ( - BigInteger, - Column, - DateTime, - Index, - Integer, - String, - UniqueConstraint, -) - -from app.db import BaseModel, TimestampMixin - - -class CachedParse(BaseModel, TimestampMixin): - __tablename__ = "etl_cache_parses" - - # Key: raw bytes + the recipe that produced the markdown. - source_sha256 = Column(String(64), nullable=False) - etl_service = Column(String(32), nullable=False) - mode = Column(String(16), nullable=False) - parser_version = Column(Integer, nullable=False) - - # Where the markdown blob lives (kept out of the row to stay small). - storage_backend = Column(String(32), nullable=False) - storage_key = Column(String, nullable=False) - size_bytes = Column(BigInteger, nullable=False) - - # Payload needed to rebuild the EtlResult on a hit. - content_type = Column(String(32), nullable=False) - actual_pages = Column(Integer, nullable=False, default=0, server_default="0") - - # Drives eviction (popularity + recency). - times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") - last_used_at = Column(DateTime(timezone=True), nullable=False) - - __table_args__ = ( - UniqueConstraint( - "source_sha256", - "etl_service", - "mode", - "parser_version", - name="uq_etl_cache_parses_key", - ), - Index("ix_etl_cache_parses_last_used_at", "last_used_at"), - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py b/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py deleted file mode 100644 index 05f40eae5..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py +++ /dev/null @@ -1,121 +0,0 @@ -"""CRUD and eviction selectors for ``etl_cache_parses`` (no business rules).""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import delete, func, select, update -from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.schemas import EvictionCandidate, ParseKey - -from .models import CachedParse - -_EVICTION_COLUMNS = ( - CachedParse.id, - CachedParse.storage_key, - CachedParse.size_bytes, - CachedParse.last_used_at, - CachedParse.times_reused, -) - - -def _as_eviction_candidate(row) -> EvictionCandidate: - return EvictionCandidate( - id=row.id, - storage_key=row.storage_key, - size_bytes=row.size_bytes, - last_used_at=row.last_used_at, - times_reused=row.times_reused, - ) - - -class CachedParseRepository: - def __init__(self, session: AsyncSession) -> None: - self._session = session - - async def get(self, key: ParseKey) -> CachedParse | None: - result = await self._session.execute( - select(CachedParse).where( - CachedParse.source_sha256 == key.source_sha256, - CachedParse.etl_service == key.etl_service, - CachedParse.mode == key.mode, - CachedParse.parser_version == key.version, - ) - ) - return result.scalars().first() - - async def insert( - self, - *, - key: ParseKey, - content_type: str, - actual_pages: int, - storage_backend: str, - storage_key: str, - size_bytes: int, - ) -> None: - # Concurrent writers parse identical bytes, so a lost race is harmless. - now = datetime.now(UTC) - await self._session.execute( - pg_insert(CachedParse) - .values( - source_sha256=key.source_sha256, - etl_service=key.etl_service, - mode=key.mode, - parser_version=key.version, - content_type=content_type, - actual_pages=actual_pages, - storage_backend=storage_backend, - storage_key=storage_key, - size_bytes=size_bytes, - times_reused=0, - last_used_at=now, - created_at=now, - ) - .on_conflict_do_nothing(constraint="uq_etl_cache_parses_key") - ) - await self._session.commit() - - async def mark_used(self, row_id: int) -> None: - await self._session.execute( - update(CachedParse) - .where(CachedParse.id == row_id) - .values( - times_reused=CachedParse.times_reused + 1, - last_used_at=datetime.now(UTC), - ) - ) - await self._session.commit() - - async def total_size_bytes(self) -> int: - result = await self._session.execute( - select(func.coalesce(func.sum(CachedParse.size_bytes), 0)) - ) - return int(result.scalar() or 0) - - async def select_expired( - self, *, cutoff: datetime, limit: int - ) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .where(CachedParse.last_used_at < cutoff) - .order_by(CachedParse.last_used_at.asc()) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .order_by(CachedParse.times_reused.asc(), CachedParse.last_used_at.asc()) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def delete_by_ids(self, ids: list[int]) -> None: - if not ids: - return - await self._session.execute(delete(CachedParse).where(CachedParse.id.in_(ids))) - await self._session.commit() diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py deleted file mode 100644 index c88ac0c72..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pure value objects for the parse cache.""" - -from __future__ import annotations - -from .eviction_candidate import EvictionCandidate -from .parse_key import ParseKey - -__all__ = [ - "EvictionCandidate", - "ParseKey", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py b/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py deleted file mode 100644 index 13a903e7d..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Row projection handed to the eviction policy.""" - -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime - - -@dataclass(frozen=True, slots=True) -class EvictionCandidate: - id: int - storage_key: str - size_bytes: int - last_used_at: datetime - times_reused: int diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py deleted file mode 100644 index 88133a418..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Identity of a cacheable parse: equal keys yield identical markdown.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class ParseKey: - source_sha256: str - etl_service: str - mode: str - version: int - - @classmethod - def for_document( - cls, source_sha256: str, *, etl_service: str, mode: str, version: int - ) -> ParseKey: - return cls( - source_sha256=source_sha256, - etl_service=etl_service, - mode=mode, - version=version, - ) - - @property - def object_suffix(self) -> str: - return f"{self.etl_service}.{self.mode}.v{self.version}.md" diff --git a/surfsense_backend/app/etl_pipeline/cache/service.py b/surfsense_backend/app/etl_pipeline/cache/service.py deleted file mode 100644 index 49398faf8..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/service.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Recall and remember parser output, coordinating the index and blob store.""" - -from __future__ import annotations - -import logging - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.etl_pipeline.etl_document import EtlResult - -logger = logging.getLogger(__name__) - - -class EtlCacheService: - def __init__(self, session: AsyncSession) -> None: - self._index = CachedParseRepository(session) - self._store = MarkdownCacheStore() - - async def recall(self, key: ParseKey) -> EtlResult | None: - """Return the cached result, or None on a miss.""" - row = await self._index.get(key) - if row is None: - return None - - try: - markdown = await self._store.load(row.storage_key) - except Exception: - # Index points at a blob that is gone; treat as a miss and re-parse. - logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) - return None - - await self._index.mark_used(row.id) - return EtlResult( - markdown_content=markdown, - etl_service=row.etl_service, - actual_pages=row.actual_pages, - content_type=row.content_type, - ) - - async def remember(self, key: ParseKey, result: EtlResult) -> None: - """Store a freshly parsed result for future reuse.""" - storage_key = await self._store.save(key, result.markdown_content) - await self._index.insert( - key=key, - content_type=result.content_type, - actual_pages=result.actual_pages, - storage_backend=self._store.backend_name, - storage_key=storage_key, - size_bytes=len(result.markdown_content.encode("utf-8")), - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/settings.py b/surfsense_backend/app/etl_pipeline/cache/settings.py deleted file mode 100644 index 5911ea222..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/settings.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Cache configuration resolved from the central ``Config``.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class EtlCacheSettings: - enabled: bool - parser_version: int - ttl_days: int - max_total_bytes: int - eviction_batch: int - # None for any storage_* field means: reuse the main file_storage backend. - storage_backend: str | None - storage_container: str | None - storage_local_root: str | None - - -def load_etl_cache_settings() -> EtlCacheSettings: - from app.config import config - - return EtlCacheSettings( - enabled=config.ETL_CACHE_ENABLED, - parser_version=config.ETL_CACHE_PARSER_VERSION, - ttl_days=config.ETL_CACHE_TTL_DAYS, - max_total_bytes=config.ETL_CACHE_MAX_TOTAL_MB * 1024 * 1024, - eviction_batch=config.ETL_CACHE_EVICTION_BATCH, - storage_backend=config.ETL_CACHE_STORAGE_BACKEND or None, - storage_container=config.ETL_CACHE_STORAGE_CONTAINER or None, - storage_local_root=config.ETL_CACHE_STORAGE_LOCAL_PATH or None, - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py b/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py deleted file mode 100644 index bed39c510..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Blob storage for cached parse markdown.""" - -from __future__ import annotations - -from .markdown_store import MarkdownCacheStore - -__all__ = [ - "MarkdownCacheStore", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/backend.py b/surfsense_backend/app/etl_pipeline/cache/storage/backend.py deleted file mode 100644 index 4f68ac0d3..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/backend.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Resolve the storage backend for cache blobs: shared main store or a dedicated one.""" - -from __future__ import annotations - -from functools import lru_cache - -from app.file_storage.backends.base import StorageBackend - - -@lru_cache(maxsize=1) -def resolve_cache_backend() -> StorageBackend: - from app.etl_pipeline.cache.settings import load_etl_cache_settings - - settings = load_etl_cache_settings() - - if not settings.storage_backend: - from app.file_storage.factory import get_storage_backend - - return get_storage_backend() - - backend = settings.storage_backend.strip().lower() - - if backend == "azure": - from app.config import config - - if not settings.storage_container: - raise ValueError("ETL_CACHE_STORAGE_CONTAINER is required for azure cache.") - if not config.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError( - "AZURE_STORAGE_CONNECTION_STRING is required for azure cache." - ) - from app.file_storage.backends.azure import AzureBlobBackend - - return AzureBlobBackend( - connection_string=config.AZURE_STORAGE_CONNECTION_STRING, - container=settings.storage_container, - ) - - if backend == "local": - if not settings.storage_local_root: - raise ValueError( - "ETL_CACHE_STORAGE_LOCAL_PATH is required for local cache." - ) - from app.file_storage.backends.local import LocalFileBackend - - return LocalFileBackend(settings.storage_local_root) - - raise ValueError(f"Unknown ETL_CACHE_STORAGE_BACKEND: {settings.storage_backend!r}") diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py b/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py deleted file mode 100644 index 189f3508b..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Read and write cached markdown blobs through the resolved backend.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage.backend import resolve_cache_backend -from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key - -_MARKDOWN_CONTENT_TYPE = "text/markdown; charset=utf-8" - - -class MarkdownCacheStore: - def __init__(self) -> None: - self._backend = resolve_cache_backend() - - @property - def backend_name(self) -> str: - return self._backend.backend_name - - async def save(self, key: ParseKey, markdown: str) -> str: - """Persist the markdown and return its storage key for the index row.""" - storage_key = build_parse_object_key(key) - await self._backend.put( - storage_key, - markdown.encode("utf-8"), - content_type=_MARKDOWN_CONTENT_TYPE, - ) - return storage_key - - async def load(self, storage_key: str) -> str: - chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] - return b"".join(chunks).decode("utf-8") - - async def delete(self, storage_key: str) -> None: - await self._backend.delete(storage_key) diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py deleted file mode 100644 index 7b89c3f92..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Object keys for cached markdown, namespaced under a dedicated prefix.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.schemas import ParseKey - -CACHE_PREFIX = "etl_cache" - - -def build_parse_object_key(key: ParseKey) -> str: - # Content-addressed: identical bytes + recipe always map to the same key. - return f"{CACHE_PREFIX}/{key.source_sha256}/{key.object_suffix}" diff --git a/surfsense_backend/app/gateway/__init__.py b/surfsense_backend/app/gateway/__init__.py index 89b931bc3..8b79b3160 100644 --- a/surfsense_backend/app/gateway/__init__.py +++ b/surfsense_backend/app/gateway/__init__.py @@ -8,7 +8,7 @@ from app.config import config def require_gateway_enabled() -> None: - """FastAPI dependency that gates gateway operational routes on the global flag. + """FastAPI dependency that gates all gateway HTTP routes on the global flag. Returns 404 (rather than 503) when ``GATEWAY_ENABLED`` is FALSE so that disabling the gateway makes its webhook/OAuth/pairing surface indistinguishable diff --git a/surfsense_backend/app/indexing_pipeline/cache/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/__init__.py deleted file mode 100644 index d3b9e5f0d..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Content-addressed reuse of chunk+embedding output across workspaces.""" - -from __future__ import annotations - -from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings -from app.indexing_pipeline.cache.service import EmbeddingCacheService - -__all__ = [ - "EmbeddingCacheService", - "build_chunk_embeddings", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py deleted file mode 100644 index 95321a229..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Entry point: serve chunk embeddings from cache, embedding only on a miss. - -Embeddings are a pure function of the markdown, the embedding model, and the -chunker -- so identical markdown is chunked and embedded once and reused across -workspaces, even when it came from different sources. -""" - -from __future__ import annotations - -import asyncio -import hashlib -import logging - -import numpy as np - -from app.config import config -from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.service import EmbeddingCacheService -from app.indexing_pipeline.cache.settings import load_embedding_cache_settings -from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid -from app.indexing_pipeline.document_embedder import embed_texts -from app.observability import metrics - -logger = logging.getLogger(__name__) - -ChunkPair = tuple[str, np.ndarray] - - -async def build_chunk_embeddings( - markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - """Return the document-level vector and ordered ``(chunk_text, vector)`` pairs. - - Drop-in for the inline chunk+embed step; reuses prior output when the same - markdown has already been embedded with the current model and chunker. - """ - settings = load_embedding_cache_settings() - chunker_kind = "code" if use_code_chunker else "hybrid" - embedding_dim = getattr(config.embedding_model_instance, "dimension", None) - - cacheable = is_embedding_cacheable( - cache_enabled=settings.enabled, - embedding_model=config.EMBEDDING_MODEL, - embedding_dim=embedding_dim, - ) - if not cacheable: - return await _compute(markdown, use_code_chunker=use_code_chunker) - - key = EmbeddingKey( - markdown_sha256=_hash_text(markdown), - embedding_model=config.EMBEDDING_MODEL, - embedding_dim=int(embedding_dim), - chunker_kind=chunker_kind, - chunker_version=settings.chunker_version, - ) - - cached = await _recall(key) - if cached is not None: - metrics.record_embedding_cache_lookup( - embedding_model=key.embedding_model, - chunker_kind=chunker_kind, - outcome="hit", - ) - logger.debug("Embedding cache hit for %s", key.markdown_sha256) - return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] - - metrics.record_embedding_cache_lookup( - embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss" - ) - summary_embedding, chunk_pairs = await _compute( - markdown, use_code_chunker=use_code_chunker - ) - await _remember(key, summary_embedding, chunk_pairs) - return summary_embedding, chunk_pairs - - -async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]: - """Chunk markdown into ordered texts with the pipeline's chunker selection.""" - if use_code_chunker: - return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True) - # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). - return await asyncio.to_thread(chunk_text_hybrid, markdown) - - -async def embed_batch(texts: list[str]) -> list[np.ndarray]: - """Embed texts in one batch off the event loop.""" - return await asyncio.to_thread(embed_texts, texts) - - -async def _compute( - markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker) - embeddings = await embed_batch([markdown, *chunk_texts]) - summary_embedding, *chunk_embeddings = embeddings - return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False)) - - -async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: - # Caching is best-effort: any failure falls through to a normal embed. - try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - return await EmbeddingCacheService(session).recall(key) - except Exception: - logger.warning("Embedding cache recall failed; embedding fresh", exc_info=True) - return None - - -async def _remember( - key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair] -) -> None: - try: - from app.tasks.celery_tasks import get_celery_session_maker - - embedding_set = EmbeddingSet( - summary_embedding=summary_embedding, - chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs], - ) - async with get_celery_session_maker()() as session: - await EmbeddingCacheService(session).remember(key, embedding_set) - except Exception: - logger.warning("Embedding cache write failed; result not cached", exc_info=True) - - -def _hash_text(text: str) -> str: - return hashlib.sha256(text.encode("utf-8")).hexdigest() diff --git a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py deleted file mode 100644 index 446bea2f8..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Gating rule: may this document be served from / written to the embedding cache?""" - -from __future__ import annotations - - -def is_embedding_cacheable( - *, - cache_enabled: bool, - embedding_model: str | None, - embedding_dim: int | None, -) -> bool: - """Cache only when a concrete embedding model and dimension are configured. - - Without a model there is nothing to key against, and without a dimension the - blob's integrity guard cannot run -- both bypass the cache. - """ - if not cache_enabled: - return False - if not embedding_model: - return False - return bool(embedding_dim) diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py deleted file mode 100644 index a0f74b360..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Background pruning of the embedding cache by age and size budget.""" - -from __future__ import annotations - -from .task import evict_embedding_cache_task - -__all__ = [ - "evict_embedding_cache_task", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py deleted file mode 100644 index 70eff6ea5..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Celery task that prunes the embedding cache by TTL, then by size budget.""" - -from __future__ import annotations - -import contextlib -import logging -from datetime import UTC, datetime, timedelta - -from app.celery_app import celery_app -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.settings import load_embedding_cache_settings -from app.indexing_pipeline.cache.storage import EmbeddingCacheStore -from app.observability import metrics -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="evict_embedding_cache") -def evict_embedding_cache_task(): - return run_async_celery_task(_evict) - - -async def _evict() -> None: - """Expire stale entries, then shed the coldest overflow only if still over budget.""" - settings = load_embedding_cache_settings() - if not settings.enabled: - return - - store = EmbeddingCacheStore() - async with get_celery_session_maker()() as session: - index = CachedEmbeddingSetRepository(session) - - cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) - expired = await index.select_expired( - cutoff=cutoff, limit=settings.eviction_batch - ) - await _drop(index, store, expired, phase="ttl") - - total = await index.total_size_bytes() - if total > settings.max_total_bytes: - coldest = await index.select_coldest(limit=settings.eviction_batch) - over_budget = select_over_budget( - coldest, - current_total_bytes=total, - max_total_bytes=settings.max_total_bytes, - ) - await _drop(index, store, over_budget, phase="size") - - -async def _drop( - index: CachedEmbeddingSetRepository, - store: EmbeddingCacheStore, - candidates: list[EvictionCandidate], - *, - phase: str, -) -> None: - if not candidates: - return - for candidate in candidates: - # Drop the index row even if the blob delete fails (orphan blob is harmless). - with contextlib.suppress(Exception): - await store.delete(candidate.storage_key) - await index.delete_by_ids([candidate.id for candidate in candidates]) - metrics.record_embedding_cache_eviction(len(candidates), phase=phase) - logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py deleted file mode 100644 index 62cde0d05..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Database access for cached embedding sets.""" - -from __future__ import annotations - -from .models import CachedEmbeddingSet -from .repository import CachedEmbeddingSetRepository - -__all__ = [ - "CachedEmbeddingSet", - "CachedEmbeddingSetRepository", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py deleted file mode 100644 index af34d92d2..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py +++ /dev/null @@ -1,47 +0,0 @@ -"""``embedding_cache_sets``: one reusable chunk+embedding set per markdown.""" - -from __future__ import annotations - -from sqlalchemy import ( - BigInteger, - Column, - DateTime, - Index, - Integer, - String, - UniqueConstraint, -) - -from app.db import BaseModel, TimestampMixin - - -class CachedEmbeddingSet(BaseModel, TimestampMixin): - __tablename__ = "embedding_cache_sets" - - # Key: markdown text + the recipe that turned it into vectors. - markdown_sha256 = Column(String(64), nullable=False) - embedding_model = Column(String(255), nullable=False) - embedding_dim = Column(Integer, nullable=False) - chunker_kind = Column(String(8), nullable=False) - chunker_version = Column(Integer, nullable=False) - - # Where the embedding blob lives (kept out of the row to stay small). - storage_backend = Column(String(32), nullable=False) - storage_key = Column(String, nullable=False) - size_bytes = Column(BigInteger, nullable=False) - chunk_count = Column(Integer, nullable=False, default=0, server_default="0") - - # Drives eviction (popularity + recency). - times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") - last_used_at = Column(DateTime(timezone=True), nullable=False) - - __table_args__ = ( - UniqueConstraint( - "markdown_sha256", - "embedding_model", - "chunker_kind", - "chunker_version", - name="uq_embedding_cache_sets_key", - ), - Index("ix_embedding_cache_sets_last_used_at", "last_used_at"), - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py deleted file mode 100644 index f7f1f4345..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py +++ /dev/null @@ -1,126 +0,0 @@ -"""CRUD and eviction selectors for ``embedding_cache_sets`` (no business rules).""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import delete, func, select, update -from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -from .models import CachedEmbeddingSet - -_EVICTION_COLUMNS = ( - CachedEmbeddingSet.id, - CachedEmbeddingSet.storage_key, - CachedEmbeddingSet.size_bytes, - CachedEmbeddingSet.last_used_at, - CachedEmbeddingSet.times_reused, -) - - -def _as_eviction_candidate(row) -> EvictionCandidate: - return EvictionCandidate( - id=row.id, - storage_key=row.storage_key, - size_bytes=row.size_bytes, - last_used_at=row.last_used_at, - times_reused=row.times_reused, - ) - - -class CachedEmbeddingSetRepository: - def __init__(self, session: AsyncSession) -> None: - self._session = session - - async def get(self, key: EmbeddingKey) -> CachedEmbeddingSet | None: - result = await self._session.execute( - select(CachedEmbeddingSet).where( - CachedEmbeddingSet.markdown_sha256 == key.markdown_sha256, - CachedEmbeddingSet.embedding_model == key.embedding_model, - CachedEmbeddingSet.chunker_kind == key.chunker_kind, - CachedEmbeddingSet.chunker_version == key.chunker_version, - ) - ) - return result.scalars().first() - - async def insert( - self, - *, - key: EmbeddingKey, - storage_backend: str, - storage_key: str, - size_bytes: int, - chunk_count: int, - ) -> None: - # Concurrent writers embed identical markdown, so a lost race is harmless. - now = datetime.now(UTC) - await self._session.execute( - pg_insert(CachedEmbeddingSet) - .values( - markdown_sha256=key.markdown_sha256, - embedding_model=key.embedding_model, - embedding_dim=key.embedding_dim, - chunker_kind=key.chunker_kind, - chunker_version=key.chunker_version, - storage_backend=storage_backend, - storage_key=storage_key, - size_bytes=size_bytes, - chunk_count=chunk_count, - times_reused=0, - last_used_at=now, - created_at=now, - ) - .on_conflict_do_nothing(constraint="uq_embedding_cache_sets_key") - ) - await self._session.commit() - - async def mark_used(self, row_id: int) -> None: - await self._session.execute( - update(CachedEmbeddingSet) - .where(CachedEmbeddingSet.id == row_id) - .values( - times_reused=CachedEmbeddingSet.times_reused + 1, - last_used_at=datetime.now(UTC), - ) - ) - await self._session.commit() - - async def total_size_bytes(self) -> int: - result = await self._session.execute( - select(func.coalesce(func.sum(CachedEmbeddingSet.size_bytes), 0)) - ) - return int(result.scalar() or 0) - - async def select_expired( - self, *, cutoff: datetime, limit: int - ) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .where(CachedEmbeddingSet.last_used_at < cutoff) - .order_by(CachedEmbeddingSet.last_used_at.asc()) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .order_by( - CachedEmbeddingSet.times_reused.asc(), - CachedEmbeddingSet.last_used_at.asc(), - ) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def delete_by_ids(self, ids: list[int]) -> None: - if not ids: - return - await self._session.execute( - delete(CachedEmbeddingSet).where(CachedEmbeddingSet.id.in_(ids)) - ) - await self._session.commit() diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py deleted file mode 100644 index c200ca1a6..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Pure value objects for the embedding cache.""" - -from __future__ import annotations - -from .embedding_key import EmbeddingKey -from .embedding_set import CachedChunk, EmbeddingSet - -__all__ = [ - "CachedChunk", - "EmbeddingKey", - "EmbeddingSet", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py deleted file mode 100644 index 55d891e73..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Identity of a cacheable embedding set: equal keys yield identical vectors. - -Embeddings depend on the markdown text, the embedding model, and the chunker -- -never on how the markdown was produced. So the key is the markdown's own hash -plus the model and chunker recipe, not the upstream parse identity. -""" - -from __future__ import annotations - -import hashlib -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class EmbeddingKey: - markdown_sha256: str - embedding_model: str - embedding_dim: int - chunker_kind: str - chunker_version: int - - @property - def object_suffix(self) -> str: - # Fingerprint the model so distinct models never share a blob, while the - # markdown hash (the object's folder) stays human-readable. - fingerprint = hashlib.sha256(self.embedding_model.encode("utf-8")).hexdigest() - return f"{fingerprint[:16]}.{self.chunker_kind}.v{self.chunker_version}.emb" diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py deleted file mode 100644 index 68c3a5211..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The cached payload: a document's chunk texts paired with their vectors.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np - - -@dataclass(frozen=True, slots=True) -class CachedChunk: - text: str - embedding: np.ndarray - - -@dataclass(frozen=True, slots=True) -class EmbeddingSet: - """Everything the indexer needs to rebuild a document's chunks without embedding. - - ``summary_embedding`` is the document-level vector; ``chunks`` are the ordered - chunk texts and their vectors. - """ - - summary_embedding: np.ndarray - chunks: list[CachedChunk] - - @property - def chunk_count(self) -> int: - return len(self.chunks) diff --git a/surfsense_backend/app/indexing_pipeline/cache/serialization.py b/surfsense_backend/app/indexing_pipeline/cache/serialization.py deleted file mode 100644 index fde0acd00..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/serialization.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle). - -Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header -carries the dim, chunk count, and ordered chunk texts; the matrix holds the -summary vector followed by one row per chunk, all float32 for compactness. -""" - -from __future__ import annotations - -import json -import struct - -import numpy as np - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet - -# Marker at the start of every blob: "SurfSense EMBeddings, version 1"-> SSEMB1. Lets us -# reject foreign blobs and bump the trailing digit if the layout ever changes. -_MAGIC = b"SSEMB1" -# 4-byte big-endian unsigned int written before the variable-length JSON header, -# so the reader knows where the header ends and the float matrix begins. -_HEADER_LEN = struct.Struct(">I") - - -def serialize(embedding_set: EmbeddingSet) -> bytes: - summary = np.asarray(embedding_set.summary_embedding, dtype=np.float32).reshape(-1) - dim = int(summary.shape[0]) - - rows = [summary] - texts: list[str] = [] - for chunk in embedding_set.chunks: - vector = np.asarray(chunk.embedding, dtype=np.float32).reshape(-1) - if vector.shape[0] != dim: - raise ValueError( - "All vectors in an embedding set must share one dimension." - ) - rows.append(vector) - texts.append(chunk.text) - - matrix = np.stack(rows, axis=0) - header = json.dumps( - {"dim": dim, "count": len(texts), "texts": texts}, ensure_ascii=False - ).encode("utf-8") - return b"".join( - [_MAGIC, _HEADER_LEN.pack(len(header)), header, matrix.tobytes(order="C")] - ) - - -def deserialize(blob: bytes) -> EmbeddingSet: - view = memoryview(blob) - if bytes(view[: len(_MAGIC)]) != _MAGIC: - raise ValueError("Unrecognized embedding cache blob.") - - offset = len(_MAGIC) - (header_len,) = _HEADER_LEN.unpack(view[offset : offset + _HEADER_LEN.size]) - offset += _HEADER_LEN.size - - header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8")) - offset += header_len - - dim = int(header["dim"]) - count = int(header["count"]) - texts: list[str] = header["texts"] - - matrix = np.frombuffer(view[offset:], dtype=np.float32) - if matrix.shape[0] != (count + 1) * dim: - raise ValueError("Embedding cache blob is truncated or corrupt.") - matrix = matrix.reshape(count + 1, dim) - - return EmbeddingSet( - summary_embedding=matrix[0], - chunks=[ - CachedChunk(text=texts[i], embedding=matrix[i + 1]) for i in range(count) - ], - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/service.py b/surfsense_backend/app/indexing_pipeline/cache/service.py deleted file mode 100644 index b1d634782..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/service.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Recall and remember embedding sets, coordinating the index and blob store.""" - -from __future__ import annotations - -import logging - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.storage import EmbeddingCacheStore - -logger = logging.getLogger(__name__) - - -class EmbeddingCacheService: - def __init__(self, session: AsyncSession) -> None: - self._index = CachedEmbeddingSetRepository(session) - self._store = EmbeddingCacheStore() - - async def recall(self, key: EmbeddingKey) -> EmbeddingSet | None: - """Return the cached embedding set, or None on a miss.""" - row = await self._index.get(key) - if row is None: - return None - - try: - embedding_set = await self._store.load(row.storage_key) - except Exception: - # Index points at a blob that is gone; treat as a miss and re-embed. - logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) - return None - - if int(embedding_set.summary_embedding.shape[0]) != key.embedding_dim: - # A model swapped its dimension under a reused name; never serve it. - logger.warning("Cached embedding dimension mismatch: %s", row.storage_key) - return None - - await self._index.mark_used(row.id) - return embedding_set - - async def remember(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> None: - """Store a freshly embedded set for future reuse.""" - storage_key, size_bytes = await self._store.save(key, embedding_set) - await self._index.insert( - key=key, - storage_backend=self._store.backend_name, - storage_key=storage_key, - size_bytes=size_bytes, - chunk_count=embedding_set.chunk_count, - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/settings.py b/surfsense_backend/app/indexing_pipeline/cache/settings.py deleted file mode 100644 index 9c6737445..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/settings.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Embedding-cache configuration resolved from the central ``Config``. - -The blob backend is intentionally not configured here: it is shared with the ETL -parse cache (see ``ETL_CACHE_STORAGE_*``). -""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class EmbeddingCacheSettings: - enabled: bool - chunker_version: int - ttl_days: int - max_total_bytes: int - eviction_batch: int - - -def load_embedding_cache_settings() -> EmbeddingCacheSettings: - from app.config import config - - return EmbeddingCacheSettings( - enabled=config.EMBEDDING_CACHE_ENABLED, - chunker_version=config.EMBEDDING_CACHE_CHUNKER_VERSION, - ttl_days=config.EMBEDDING_CACHE_TTL_DAYS, - max_total_bytes=config.EMBEDDING_CACHE_MAX_TOTAL_MB * 1024 * 1024, - eviction_batch=config.EMBEDDING_CACHE_EVICTION_BATCH, - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py deleted file mode 100644 index 72b04c34d..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Blob storage for cached embedding sets.""" - -from __future__ import annotations - -from .embedding_store import EmbeddingCacheStore - -__all__ = [ - "EmbeddingCacheStore", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py deleted file mode 100644 index 7b0329b4e..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Read and write cached embedding blobs through the shared cache backend. - -The blob backend is shared with the ETL parse cache (same bucket / root), so -markdown and its embeddings live side by side; only the object prefix differs. -""" - -from __future__ import annotations - -from app.etl_pipeline.cache.storage.backend import resolve_cache_backend -from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.serialization import deserialize, serialize -from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key - -_EMBEDDING_CONTENT_TYPE = "application/octet-stream" - - -class EmbeddingCacheStore: - def __init__(self) -> None: - self._backend = resolve_cache_backend() - - @property - def backend_name(self) -> str: - return self._backend.backend_name - - async def save( - self, key: EmbeddingKey, embedding_set: EmbeddingSet - ) -> tuple[str, int]: - """Persist the embedding set and return its storage key and byte size.""" - blob = serialize(embedding_set) - storage_key = build_embedding_object_key(key) - await self._backend.put(storage_key, blob, content_type=_EMBEDDING_CONTENT_TYPE) - return storage_key, len(blob) - - async def load(self, storage_key: str) -> EmbeddingSet: - chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] - return deserialize(b"".join(chunks)) - - async def delete(self, storage_key: str) -> None: - await self._backend.delete(storage_key) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py deleted file mode 100644 index 6286ccf90..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Object keys for cached embedding sets, namespaced under a dedicated prefix.""" - -from __future__ import annotations - -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -CACHE_PREFIX = "embedding_cache" - - -def build_embedding_object_key(key: EmbeddingKey) -> str: - # Content-addressed: identical markdown + recipe always map to the same key. - return f"{CACHE_PREFIX}/{key.markdown_sha256}/{key.object_suffix}" diff --git a/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py b/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py deleted file mode 100644 index 9354aeb9f..000000000 --- a/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Diff a document's existing chunk rows against its freshly chunked texts. - -Embeddings are a pure function of chunk text, so a row whose content reappears -in the new chunking keeps its embedding (and its HNSW/GIN index entries); only -genuinely new texts are embedded and only vanished rows are deleted. Matching -is a greedy multiset match on content in document order, so duplicate -boilerplate chunks pair up one-to-one and reordered chunks become cheap -position updates instead of delete+reinsert. -""" - -from __future__ import annotations - -from collections import defaultdict, deque -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class ExistingChunk: - id: int - content: str - position: int - - -@dataclass(frozen=True, slots=True) -class ChunkPlan: - """The minimal set of writes that turns the stored chunks into the new ones. - - ``reused`` holds only kept rows whose position actually changed; rows that - match in place need no write at all. Kept-row count (for metrics) is - ``len(existing) - len(to_delete)``. - """ - - reused: list[tuple[int, int]] # (existing_chunk_id, new_position) - to_embed: list[tuple[int, str]] # (new_position, text) - to_delete: list[int] # existing chunk ids - - -def reconcile(existing: list[ExistingChunk], new_texts: list[str]) -> ChunkPlan: - available: dict[str, deque[ExistingChunk]] = defaultdict(deque) - for chunk in sorted(existing, key=lambda c: c.position): - available[chunk.content].append(chunk) - - reused: list[tuple[int, int]] = [] - to_embed: list[tuple[int, str]] = [] - - for new_position, text in enumerate(new_texts): - matches = available.get(text) - if matches: - chunk = matches.popleft() - if chunk.position != new_position: - reused.append((chunk.id, new_position)) - else: - to_embed.append((new_position, text)) - - to_delete = [chunk.id for queue in available.values() for chunk in queue] - return ChunkPlan(reused=reused, to_embed=to_embed, to_delete=to_delete) diff --git a/surfsense_backend/app/indexing_pipeline/document_persistence.py b/surfsense_backend/app/indexing_pipeline/document_persistence.py index b716560d2..9fd8867e2 100644 --- a/surfsense_backend/app/indexing_pipeline/document_persistence.py +++ b/surfsense_backend/app/indexing_pipeline/document_persistence.py @@ -1,12 +1,12 @@ import contextlib import logging -import time from datetime import UTC, datetime from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import set_committed_value -from app.db import Chunk, Document, DocumentStatus +from app.db import Document, DocumentStatus logger = logging.getLogger(__name__) @@ -22,6 +22,7 @@ async def rollback_and_persist_failure( try: await session.rollback() except Exception: + # Session is completely dead; surface it but never raise. logger.warning( "Rollback failed; cannot persist failed status for document %s", getattr(document, "id", "unknown"), @@ -34,6 +35,8 @@ async def rollback_and_persist_failure( document.status = DocumentStatus.failed(message) await session.commit() except Exception: + # Best-effort: the document stays non-ready and is retried next sync. + # Log it so a permanently-stuck document is at least traceable. logger.warning( "Could not persist failed status for document %s; will retry next sync", getattr(document, "id", "unknown"), @@ -43,60 +46,12 @@ async def rollback_and_persist_failure( await session.rollback() -async def persist_scratch_index( - session: AsyncSession, - document: Document, - content: str, - chunks: list[Chunk], - *, - batch_size: int, - perf: logging.Logger, -) -> None: - """Commit document content first, then chunk rows in batches, then mark ready.""" - if document.id is None: - raise ValueError("document.id is required to persist chunks") - - document.content = content - document.updated_at = datetime.now(UTC) - await session.commit() - - t_persist = time.perf_counter() - total = len(chunks) - if total == 0: - set_committed_value(document, "chunks", []) - document.status = DocumentStatus.ready() - document.updated_at = datetime.now(UTC) - await session.commit() - return - - effective_batch = total if batch_size <= 0 else batch_size - num_batches = (total + effective_batch - 1) // effective_batch - doc_id = document.id - - for batch_idx, start in enumerate(range(0, total, effective_batch), start=1): - batch = chunks[start : start + effective_batch] - t_batch = time.perf_counter() - for chunk in batch: - chunk.document_id = doc_id - session.add_all(batch) - await session.commit() - perf.info( - "[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs", - doc_id, - batch_idx, - num_batches, - len(batch), - time.perf_counter() - t_batch, - ) - +def attach_chunks_to_document(document: Document, chunks: list) -> None: + """Assign chunks to a document without triggering SQLAlchemy async lazy loading.""" set_committed_value(document, "chunks", chunks) - document.status = DocumentStatus.ready() - document.updated_at = datetime.now(UTC) - await session.commit() - perf.info( - "[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs", - doc_id, - total, - num_batches, - time.perf_counter() - t_persist, - ) + session = object_session(document) + if session is not None: + if document.id is not None: + for chunk in chunks: + chunk.document_id = document.id + session.add_all(chunks) diff --git a/surfsense_backend/app/indexing_pipeline/exceptions.py b/surfsense_backend/app/indexing_pipeline/exceptions.py index bf9d9e9fa..666fa4b9f 100644 --- a/surfsense_backend/app/indexing_pipeline/exceptions.py +++ b/surfsense_backend/app/indexing_pipeline/exceptions.py @@ -14,8 +14,6 @@ from litellm.exceptions import ( ) from sqlalchemy.exc import IntegrityError as IntegrityError -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception - # Tuples for use directly in except clauses. RETRYABLE_LLM_ERRORS = ( RateLimitError, @@ -99,20 +97,38 @@ def safe_exception_message(exc: Exception) -> str: def llm_retryable_message(exc: Exception) -> str: try: - adapted = adapt_llm_exception(exc) - if adapted.category is LLMErrorCategory.UNKNOWN: - return safe_exception_message(exc) - return adapted.user_message + if isinstance(exc, RateLimitError): + return PipelineMessages.RATE_LIMIT + if isinstance(exc, Timeout): + return PipelineMessages.LLM_TIMEOUT + if isinstance(exc, ServiceUnavailableError): + return PipelineMessages.LLM_UNAVAILABLE + if isinstance(exc, BadGatewayError): + return PipelineMessages.LLM_BAD_GATEWAY + if isinstance(exc, InternalServerError): + return PipelineMessages.LLM_SERVER_ERROR + if isinstance(exc, APIConnectionError): + return PipelineMessages.LLM_CONNECTION + return safe_exception_message(exc) except Exception: return "Something went wrong when calling the LLM." def llm_permanent_message(exc: Exception) -> str: try: - adapted = adapt_llm_exception(exc) - if adapted.category is LLMErrorCategory.UNKNOWN: - return safe_exception_message(exc) - return adapted.user_message + if isinstance(exc, AuthenticationError): + return PipelineMessages.LLM_AUTH + if isinstance(exc, PermissionDeniedError): + return PipelineMessages.LLM_PERMISSION + if isinstance(exc, NotFoundError): + return PipelineMessages.LLM_NOT_FOUND + if isinstance(exc, BadRequestError): + return PipelineMessages.LLM_BAD_REQUEST + if isinstance(exc, UnprocessableEntityError): + return PipelineMessages.LLM_UNPROCESSABLE + if isinstance(exc, APIResponseValidationError): + return PipelineMessages.LLM_RESPONSE + return safe_exception_message(exc) except Exception: return "Something went wrong when calling the LLM." diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 30ea9d5d6..67a6778e0 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import UTC, datetime -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -19,17 +19,16 @@ from app.db import ( DocumentStatus, DocumentType, ) -from app.indexing_pipeline.cache import build_chunk_embeddings -from app.indexing_pipeline.cache.cached_indexing import chunk_markdown, embed_batch -from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid +from app.indexing_pipeline.document_embedder import embed_texts from app.indexing_pipeline.document_hashing import ( compute_content_hash, compute_identifier_hash, compute_unique_identifier_hash, ) from app.indexing_pipeline.document_persistence import ( - persist_scratch_index, + attach_chunks_to_document, rollback_and_persist_failure, ) from app.indexing_pipeline.exceptions import ( @@ -381,50 +380,53 @@ class IndexingPipelineService: content = connector_doc.source_markdown - t_step = time.perf_counter() - existing = await self._load_existing_chunks(document.id) - if existing and self._reconcile_enabled(): - chunk_count = await self._reindex_incrementally( - document, content, connector_doc, existing - ) - perf.info( - "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, - chunk_count, - time.perf_counter() - t_step, - ) - document.content = content - document.updated_at = datetime.now(UTC) - document.status = DocumentStatus.ready() - await self.session.commit() - else: - from app.config import config + await self.session.execute( + delete(Chunk).where(Chunk.document_id == document.id) + ) - chunks = await self._reindex_from_scratch( - document, content, connector_doc + t_step = time.perf_counter() + if connector_doc.should_use_code_chunker: + chunk_texts = await asyncio.to_thread( + chunk_text, + connector_doc.source_markdown, + use_code_chunker=True, ) - chunk_count = len(chunks) - perf.info( - "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, - chunk_count, - time.perf_counter() - t_step, - ) - await persist_scratch_index( - self.session, - document, - content, - chunks, - batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE, - perf=perf, + else: + # Use the table-aware hybrid chunker so Markdown tables are not + # split mid-row (see issue #1334). + chunk_texts = await asyncio.to_thread( + chunk_text_hybrid, + connector_doc.source_markdown, ) + + texts_to_embed = [content, *chunk_texts] + embeddings = await asyncio.to_thread(embed_texts, texts_to_embed) + summary_embedding, *chunk_embeddings = embeddings + + chunks = [ + Chunk(content=text, embedding=emb) + for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) + ] + perf.info( + "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", + document.id, + len(chunks), + time.perf_counter() - t_step, + ) + + document.content = content + document.embedding = summary_embedding + attach_chunks_to_document(document, chunks) + document.updated_at = datetime.now(UTC) + document.status = DocumentStatus.ready() + await self.session.commit() perf.info( "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", document.id, - chunk_count, + len(chunks), time.perf_counter() - t_index, ) - log_index_success(ctx, chunk_count=chunk_count) + log_index_success(ctx, chunk_count=len(chunks)) outcome_status = "success" await self._enqueue_ai_sort_if_enabled(document) @@ -481,89 +483,6 @@ class IndexingPipelineService: persist_span_cm.__exit__(*sys.exc_info()) return document - @staticmethod - def _reconcile_enabled() -> bool: - from app.config import config - - return config.CHUNK_RECONCILE_ENABLED - - async def _load_existing_chunks(self, document_id: int) -> list[ExistingChunk]: - result = await self.session.execute( - select(Chunk.id, Chunk.content, Chunk.position).where( - Chunk.document_id == document_id - ) - ) - return [ - ExistingChunk(id=row.id, content=row.content, position=row.position) - for row in result - ] - - async def _reindex_from_scratch( - self, document: Document, content: str, connector_doc: ConnectorDocument - ) -> list[Chunk]: - await self.session.execute( - delete(Chunk).where(Chunk.document_id == document.id) - ) - - summary_embedding, chunk_pairs = await build_chunk_embeddings( - content, - use_code_chunker=connector_doc.should_use_code_chunker, - ) - - document.embedding = summary_embedding - return [ - Chunk(content=text, embedding=emb, position=i) - for i, (text, emb) in enumerate(chunk_pairs) - ] - - async def _reindex_incrementally( - self, - document: Document, - content: str, - connector_doc: ConnectorDocument, - existing: list[ExistingChunk], - ) -> int: - """Edit path: keep rows whose text survived, embed only new texts. - - Unchanged rows keep their embedding and their HNSW/GIN index entries; - moved rows get a position-only UPDATE, which touches neither index. - """ - new_texts = await chunk_markdown( - content, use_code_chunker=connector_doc.should_use_code_chunker - ) - plan = reconcile(existing, new_texts) - - # One batch: the document-level summary vector plus the missing chunks. - embeddings = await embed_batch([content, *[t for _, t in plan.to_embed]]) - summary_embedding, *new_embeddings = embeddings - - if plan.reused: - await self.session.execute( - update(Chunk), - [{"id": cid, "position": pos} for cid, pos in plan.reused], - ) - if plan.to_delete: - await self.session.execute( - delete(Chunk).where(Chunk.id.in_(plan.to_delete)) - ) - self.session.add_all( - Chunk( - content=text, - embedding=emb, - position=pos, - document_id=document.id, - ) - for (pos, text), emb in zip(plan.to_embed, new_embeddings, strict=True) - ) - document.embedding = summary_embedding - - ot_metrics.record_chunk_reconcile( - reused=len(existing) - len(plan.to_delete), - embedded=len(plan.to_embed), - deleted=len(plan.to_delete), - ) - return len(new_texts) - async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled.""" try: diff --git a/surfsense_backend/app/notifications/api/api.py b/surfsense_backend/app/notifications/api/api.py index 9a136ca7b..ddca09c66 100644 --- a/surfsense_backend/app/notifications/api/api.py +++ b/surfsense_backend/app/notifications/api/api.py @@ -275,7 +275,7 @@ async def list_notifications( query = query.where(unread_filter) count_query = count_query.where(unread_filter) elif filter == "errors": - error_filter = (Notification.type == "insufficient_credits") | ( + error_filter = (Notification.type == "page_limit_exceeded") | ( Notification.notification_metadata["status"].astext == "failed" ) query = query.where(error_filter) diff --git a/surfsense_backend/app/notifications/constants.py b/surfsense_backend/app/notifications/constants.py index 4c7139972..e8bd8391d 100644 --- a/surfsense_backend/app/notifications/constants.py +++ b/surfsense_backend/app/notifications/constants.py @@ -2,9 +2,6 @@ from __future__ import annotations -# Matches notifications.title VARCHAR(200). -TITLE_MAX_LENGTH = 200 - # Notifications newer than this are live-synced; older ones load via the list endpoint. SYNC_WINDOW_DAYS = 14 @@ -15,7 +12,6 @@ CATEGORY_TYPES: dict[str, tuple[str, ...]] = { "connector_indexing", "connector_deletion", "document_processing", - "insufficient_credits", - "auto_reload_failed", + "page_limit_exceeded", ), } diff --git a/surfsense_backend/app/notifications/service/facade.py b/surfsense_backend/app/notifications/service/facade.py index 9f4ad50d0..63154301c 100644 --- a/surfsense_backend/app/notifications/service/facade.py +++ b/surfsense_backend/app/notifications/service/facade.py @@ -10,12 +10,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.notifications.persistence import Notification from app.notifications.service.handlers import ( - AutoReloadFailedNotificationHandler, CommentReplyNotificationHandler, ConnectorIndexingNotificationHandler, DocumentProcessingNotificationHandler, - InsufficientCreditsNotificationHandler, MentionNotificationHandler, + PageLimitNotificationHandler, ) logger = logging.getLogger(__name__) @@ -28,8 +27,7 @@ class NotificationService: document_processing = DocumentProcessingNotificationHandler() mention = MentionNotificationHandler() comment_reply = CommentReplyNotificationHandler() - insufficient_credits = InsufficientCreditsNotificationHandler() - auto_reload_failed = AutoReloadFailedNotificationHandler() + page_limit = PageLimitNotificationHandler() @staticmethod async def create_notification( diff --git a/surfsense_backend/app/notifications/service/handlers/__init__.py b/surfsense_backend/app/notifications/service/handlers/__init__.py index 1a6168e37..8c32dea3b 100644 --- a/surfsense_backend/app/notifications/service/handlers/__init__.py +++ b/surfsense_backend/app/notifications/service/handlers/__init__.py @@ -2,18 +2,16 @@ from __future__ import annotations -from .auto_reload_failed import AutoReloadFailedNotificationHandler from .comment_reply import CommentReplyNotificationHandler from .connector_indexing import ConnectorIndexingNotificationHandler from .document_processing import DocumentProcessingNotificationHandler -from .insufficient_credits import InsufficientCreditsNotificationHandler from .mention import MentionNotificationHandler +from .page_limit import PageLimitNotificationHandler __all__ = [ - "AutoReloadFailedNotificationHandler", "CommentReplyNotificationHandler", "ConnectorIndexingNotificationHandler", "DocumentProcessingNotificationHandler", - "InsufficientCreditsNotificationHandler", "MentionNotificationHandler", + "PageLimitNotificationHandler", ] diff --git a/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py b/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py deleted file mode 100644 index 0234a436d..000000000 --- a/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Notifications for failed off-session credit auto-reload charges.""" - -from __future__ import annotations - -import logging -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.notifications.persistence import Notification -from app.notifications.service.base import BaseNotificationHandler -from app.notifications.service.messages import auto_reload_failed as msg - -logger = logging.getLogger(__name__) - - -class AutoReloadFailedNotificationHandler(BaseNotificationHandler): - """Notifications for declined auto-reload top-ups.""" - - def __init__(self): - super().__init__("auto_reload_failed") - - async def notify_auto_reload_failed( - self, - session: AsyncSession, - user_id: UUID, - amount_micros: int, - payment_intent_id: str | None = None, - reason: str | None = None, - ) -> Notification: - """Notify that an off-session auto-reload charge was declined. - - Not tied to a search space (``search_space_id`` is None); the action - links to the billing settings so the user can fix their card. - """ - op_id = msg.operation_id(payment_intent_id or "") - title, message = msg.summary(amount_micros, reason) - - return await self.find_or_create_notification( - session=session, - user_id=user_id, - operation_id=op_id, - title=title, - message=message, - search_space_id=None, - initial_metadata={ - "amount_micros": amount_micros, - "payment_intent_id": payment_intent_id, - "status": "failed", - "error_type": "auto_reload_failed", - "action_url": "/dashboard", - "action_label": "Update card", - }, - ) diff --git a/surfsense_backend/app/notifications/service/handlers/document_processing.py b/surfsense_backend/app/notifications/service/handlers/document_processing.py index 714c4f1aa..8644df2c8 100644 --- a/surfsense_backend/app/notifications/service/handlers/document_processing.py +++ b/surfsense_backend/app/notifications/service/handlers/document_processing.py @@ -28,7 +28,7 @@ class DocumentProcessingNotificationHandler(BaseNotificationHandler): ) -> Notification: """Open the notification when document processing is queued.""" operation_id = msg.operation_id(document_type, document_name, search_space_id) - title = msg.started_title(document_name) + title = f"Processing: {document_name}" message = "Waiting in queue" metadata = { diff --git a/surfsense_backend/app/notifications/service/handlers/insufficient_credits.py b/surfsense_backend/app/notifications/service/handlers/page_limit.py similarity index 55% rename from surfsense_backend/app/notifications/service/handlers/insufficient_credits.py rename to surfsense_backend/app/notifications/service/handlers/page_limit.py index 46124f222..90722dc62 100644 --- a/surfsense_backend/app/notifications/service/handlers/insufficient_credits.py +++ b/surfsense_backend/app/notifications/service/handlers/page_limit.py @@ -1,4 +1,4 @@ -"""Notifications for running out of credit during document processing.""" +"""Notifications for exceeding the page limit.""" from __future__ import annotations @@ -9,42 +9,46 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.notifications.persistence import Notification from app.notifications.service.base import BaseNotificationHandler -from app.notifications.service.messages import insufficient_credits as msg +from app.notifications.service.messages import page_limit as msg logger = logging.getLogger(__name__) -class InsufficientCreditsNotificationHandler(BaseNotificationHandler): - """Notifications for running out of credit during document processing.""" +class PageLimitNotificationHandler(BaseNotificationHandler): + """Notifications for exceeding the page limit.""" def __init__(self): - super().__init__("insufficient_credits") + super().__init__("page_limit_exceeded") - async def notify_insufficient_credits( + async def notify_page_limit_exceeded( self, session: AsyncSession, user_id: UUID, document_name: str, document_type: str, search_space_id: int, - balance_micros: int, - required_micros: int, + pages_used: int, + pages_limit: int, + pages_to_add: int, ) -> Notification: - """Notify that a document was blocked by insufficient credit.""" + """Notify that a document was blocked by the page limit.""" operation_id = msg.operation_id(document_name, search_space_id) - title, message = msg.summary(document_name, balance_micros, required_micros) + title, message = msg.summary( + document_name, pages_used, pages_limit, pages_to_add + ) metadata = { "operation_id": operation_id, "document_name": document_name, "document_type": document_type, - "balance_micros": balance_micros, - "required_micros": required_micros, + "pages_used": pages_used, + "pages_limit": pages_limit, + "pages_to_add": pages_to_add, "status": "failed", - "error_type": "insufficient_credits", + "error_type": "page_limit_exceeded", # Where the inbox item links to. - "action_url": f"/dashboard/{search_space_id}/buy-more", - "action_label": "Buy credits", + "action_url": f"/dashboard/{search_space_id}/more-pages", + "action_label": "Upgrade Plan", } notification = Notification( @@ -59,7 +63,6 @@ class InsufficientCreditsNotificationHandler(BaseNotificationHandler): await session.commit() await session.refresh(notification) logger.info( - f"Created insufficient_credits notification {notification.id} " - f"for user {user_id}" + f"Created page_limit_exceeded notification {notification.id} for user {user_id}" ) return notification diff --git a/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py b/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py deleted file mode 100644 index 5af19623c..000000000 --- a/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Pure presentation logic for auto-reload-failure notifications.""" - -from __future__ import annotations - -from datetime import UTC, datetime - - -def operation_id(payment_intent_id: str) -> str: - """Build a unique id for an auto-reload-failure notification. - - Keyed on the failed PaymentIntent so retries of the same charge collapse - into a single inbox item rather than spamming the user. - """ - if payment_intent_id: - return f"auto_reload_failed_{payment_intent_id}" - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - return f"auto_reload_failed_{timestamp}" - - -def summary(amount_micros: int, reason: str | None) -> tuple[str, str]: - """Compute the title and message for a failed off-session auto-reload charge.""" - amount_usd = max(0, amount_micros) / 1_000_000 - title = "Auto-reload failed" - base = ( - f"We couldn't automatically add ${amount_usd:.2f} of credit because your " - "saved card was declined. Auto-reload has been turned off — update your " - "card and re-enable it to keep topping up automatically." - ) - if reason: - base = f"{base} (Reason: {reason}.)" - return title, base diff --git a/surfsense_backend/app/notifications/service/messages/document_processing.py b/surfsense_backend/app/notifications/service/messages/document_processing.py index 1f324b35d..3805c2847 100644 --- a/surfsense_backend/app/notifications/service/messages/document_processing.py +++ b/surfsense_backend/app/notifications/service/messages/document_processing.py @@ -6,8 +6,6 @@ import hashlib from datetime import UTC, datetime from typing import Any -from app.notifications.service.messages.text import format_title - def operation_id(document_type: str, filename: str, search_space_id: int) -> str: """Build a unique id for a document processing run.""" @@ -16,11 +14,6 @@ def operation_id(document_type: str, filename: str, search_space_id: int) -> str return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}" -def started_title(document_name: str) -> str: - """Title shown when document processing is queued.""" - return format_title("Processing: ", document_name) - - def progress( stage: str, stage_message: str | None = None, @@ -51,11 +44,11 @@ def completion( ) -> tuple[str, str, str, dict[str, Any]]: """Compute the final title, message, status, and metadata for a finished run.""" if error_message: - title = format_title("Failed: ", document_name) + title = f"Failed: {document_name}" message = f"Processing failed: {error_message}" status = "failed" else: - title = format_title("Ready: ", document_name) + title = f"Ready: {document_name}" message = "Now searchable!" status = "completed" diff --git a/surfsense_backend/app/notifications/service/messages/insufficient_credits.py b/surfsense_backend/app/notifications/service/messages/insufficient_credits.py deleted file mode 100644 index fad26ad91..000000000 --- a/surfsense_backend/app/notifications/service/messages/insufficient_credits.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Pure presentation logic for insufficient-credit notifications.""" - -from __future__ import annotations - -import hashlib -from datetime import UTC, datetime - -from app.notifications.service.messages.text import truncate - - -def operation_id(document_name: str, search_space_id: int) -> str: - """Build a unique id for an insufficient-credits notification.""" - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] - return f"insufficient_credits_{search_space_id}_{timestamp}_{doc_hash}" - - -def summary( - document_name: str, balance_micros: int, required_micros: int -) -> tuple[str, str]: - """Compute the title and message for a blocked-by-insufficient-credits document.""" - display_name = truncate(document_name, 40) - title = f"Insufficient credits: {display_name}" - balance_usd = max(0, balance_micros) / 1_000_000 - required_usd = max(0, required_micros) / 1_000_000 - message = ( - f"This document costs about ${required_usd:.2f} to process but you have " - f"${balance_usd:.2f} of credit left. Add more credits to continue." - ) - return title, message diff --git a/surfsense_backend/app/notifications/service/messages/page_limit.py b/surfsense_backend/app/notifications/service/messages/page_limit.py new file mode 100644 index 000000000..54e5cbdec --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/page_limit.py @@ -0,0 +1,25 @@ +"""Pure presentation logic for page-limit notifications.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime + +from app.notifications.service.messages.text import truncate + + +def operation_id(document_name: str, search_space_id: int) -> str: + """Build a unique id for a page-limit notification.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] + return f"page_limit_{search_space_id}_{timestamp}_{doc_hash}" + + +def summary( + document_name: str, pages_used: int, pages_limit: int, pages_to_add: int +) -> tuple[str, str]: + """Compute the title and message for a blocked-by-page-limit document.""" + display_name = truncate(document_name, 40) + title = f"Page limit exceeded: {display_name}" + message = f"This document has ~{pages_to_add} page(s) but you've used {pages_used}/{pages_limit} pages. Upgrade to process more documents." + return title, message diff --git a/surfsense_backend/app/notifications/service/messages/text.py b/surfsense_backend/app/notifications/service/messages/text.py index 344c9eb4e..98d5284cb 100644 --- a/surfsense_backend/app/notifications/service/messages/text.py +++ b/surfsense_backend/app/notifications/service/messages/text.py @@ -2,21 +2,7 @@ from __future__ import annotations -from app.notifications.constants import TITLE_MAX_LENGTH - def truncate(text: str, limit: int) -> str: """Return ``text`` capped at ``limit`` chars, appending an ellipsis if cut.""" return text[:limit] + "..." if len(text) > limit else text - - -def format_title(prefix: str, text: str, *, max_length: int = TITLE_MAX_LENGTH) -> str: - """Build a notification title that fits ``max_length`` including ``prefix``.""" - budget = max_length - len(prefix) - if budget <= 0: - return prefix[:max_length] - if len(text) <= budget: - return f"{prefix}{text}" - if budget <= 3: - return f"{prefix}{text[:budget]}" - return f"{prefix}{text[: budget - 3]}..." diff --git a/surfsense_backend/app/notifications/types.py b/surfsense_backend/app/notifications/types.py index f2974e584..bb8bcfab1 100644 --- a/surfsense_backend/app/notifications/types.py +++ b/surfsense_backend/app/notifications/types.py @@ -10,8 +10,7 @@ NotificationType = Literal[ "document_processing", "new_mention", "comment_reply", - "insufficient_credits", - "auto_reload_failed", + "page_limit_exceeded", ] NotificationCategory = Literal["comments", "status"] diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index ade43ab01..5ba3be059 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -289,49 +289,6 @@ def _etl_extract_outcome(): ) -@lru_cache(maxsize=1) -def _etl_cache_lookups(): - return _get_meter().create_counter( - "surfsense.etl.cache.lookups", - description="Count of ETL parse-cache lookups by outcome (hit/miss).", - ) - - -@lru_cache(maxsize=1) -def _etl_cache_evictions(): - return _get_meter().create_counter( - "surfsense.etl.cache.evictions", - description="Count of ETL parse-cache entries evicted, by phase.", - ) - - -@lru_cache(maxsize=1) -def _embedding_cache_lookups(): - return _get_meter().create_counter( - "surfsense.embedding.cache.lookups", - description="Count of embedding (chunk+embedding) cache lookups by outcome (hit/miss).", - ) - - -@lru_cache(maxsize=1) -def _embedding_cache_evictions(): - return _get_meter().create_counter( - "surfsense.embedding.cache.evictions", - description="Count of embedding cache entries evicted, by phase.", - ) - - -@lru_cache(maxsize=1) -def _chunk_reconcile_chunks(): - return _get_meter().create_counter( - "surfsense.indexing.reconcile.chunks", - description=( - "Chunks handled by incremental re-indexing, by outcome " - "(reused/embedded/deleted)." - ), - ) - - @lru_cache(maxsize=1) def _celery_heartbeat_refreshes(): return _get_meter().create_counter( @@ -713,61 +670,6 @@ def record_etl_extract_outcome( ) -def record_etl_cache_lookup( - *, etl_service: str | None, mode: str | None, outcome: str -) -> None: - """Record a parse-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" - _add( - _etl_cache_lookups(), - 1, - { - "etl.service": etl_service or "unknown", - "mode": mode or "unknown", - "outcome": outcome, - }, - ) - - -def record_etl_cache_eviction(count: int, *, phase: str) -> None: - """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" - if count <= 0: - return - _add(_etl_cache_evictions(), count, {"phase": phase}) - - -def record_embedding_cache_lookup( - *, embedding_model: str | None, chunker_kind: str | None, outcome: str -) -> None: - """Record an embedding-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" - _add( - _embedding_cache_lookups(), - 1, - { - "embedding.model": embedding_model or "unknown", - "chunker.kind": chunker_kind or "unknown", - "outcome": outcome, - }, - ) - - -def record_embedding_cache_eviction(count: int, *, phase: str) -> None: - """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" - if count <= 0: - return - _add(_embedding_cache_evictions(), count, {"phase": phase}) - - -def record_chunk_reconcile(*, reused: int, embedded: int, deleted: int) -> None: - """Record an incremental re-index: how many chunks were kept vs recomputed.""" - for outcome, count in ( - ("reused", reused), - ("embedded", embedded), - ("deleted", deleted), - ): - if count > 0: - _add(_chunk_reconcile_chunks(), count, {"outcome": outcome}) - - def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) @@ -961,14 +863,9 @@ __all__ = [ "record_celery_queue_latency", "record_chat_request_duration", "record_chat_request_outcome", - "record_chunk_reconcile", "record_compaction_run", "record_connector_sync_duration", "record_connector_sync_outcome", - "record_embedding_cache_eviction", - "record_embedding_cache_lookup", - "record_etl_cache_eviction", - "record_etl_cache_lookup", "record_etl_extract_duration", "record_etl_extract_outcome", "record_indexing_document_duration", diff --git a/surfsense_backend/app/podcasts/__init__.py b/surfsense_backend/app/podcasts/__init__.py deleted file mode 100644 index 6a152af22..000000000 --- a/surfsense_backend/app/podcasts/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Podcast feature: brief resolution, transcript drafting, and audio rendering. - -Owns the ``podcasts`` table model, which :mod:`app.db` re-exports so existing -``from app.db import Podcast`` imports keep resolving. -""" - -from __future__ import annotations - -__all__: list[str] = [] diff --git a/surfsense_backend/app/podcasts/api/__init__.py b/surfsense_backend/app/podcasts/api/__init__.py deleted file mode 100644 index 4b5b12971..000000000 --- a/surfsense_backend/app/podcasts/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""HTTP API for the podcast lifecycle.""" - -from __future__ import annotations - -from .routes import router - -__all__ = ["router"] diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py deleted file mode 100644 index cfcb2ede9..000000000 --- a/surfsense_backend/app/podcasts/api/routes.py +++ /dev/null @@ -1,360 +0,0 @@ -"""HTTP surface for the podcast lifecycle. - -Status is observed by the frontend through Zero, so these routes are about -actions (create, edit/approve the brief, regenerate, cancel) and audio delivery. -Each mutating route performs the guarded transition via the service, commits, -then enqueues the matching Celery task; lifecycle errors map to 409/422. -""" - -from __future__ import annotations - -import os -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from pathlib import Path - -from fastapi import APIRouter, Depends, HTTPException, Response -from fastapi.responses import StreamingResponse -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config as app_config -from app.db import ( - Permission, - SearchSpace, - SearchSpaceMembership, - User, - get_async_session, -) -from app.podcasts.generation.brief import propose_brief -from app.podcasts.persistence import Podcast, PodcastRepository, PodcastStatus -from app.podcasts.service import ( - InvalidTransitionError, - PodcastService, - PreconditionFailedError, - SpecConflictError, -) -from app.podcasts.storage import audio_exists, open_audio_stream, purge_audio -from app.podcasts.tasks import draft_transcript_task -from app.podcasts.tts import get_text_to_speech -from app.podcasts.voices import ( - get_voice_catalog, - provider_from_service, - render_voice_preview, -) -from app.users import current_active_user -from app.utils.rbac import check_permission - -from .schemas import ( - CreatePodcastRequest, - LanguageOptions, - PodcastDetail, - PodcastSummary, - UpdateSpecRequest, - VoiceOption, -) - -router = APIRouter() - - -@router.get("/podcasts", response_model=list[PodcastSummary]) -async def list_podcasts( - search_space_id: int | None = None, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if skip < 0 or limit < 1: - raise HTTPException(status_code=400, detail="Invalid pagination parameters") - - if search_space_id is not None: - await _require(session, user, search_space_id, Permission.PODCASTS_READ) - query = ( - select(Podcast) - .where(Podcast.search_space_id == search_space_id) - .order_by(Podcast.created_at.desc()) - .offset(skip) - .limit(limit) - ) - else: - query = ( - select(Podcast) - .join(SearchSpace) - .join(SearchSpaceMembership) - .where(SearchSpaceMembership.user_id == user.id) - .order_by(Podcast.created_at.desc()) - .offset(skip) - .limit(limit) - ) - result = await session.execute(query) - return list(result.scalars().all()) - - -@router.get("/podcasts/voices", response_model=list[VoiceOption]) -async def list_voices(language: str | None = None): - """Voices the active TTS provider offers, optionally filtered by language.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - catalog = get_voice_catalog() - voices = ( - catalog.for_language(provider, language) - if language - else catalog.for_provider(provider) - ) - return [ - VoiceOption( - voice_id=v.voice_id, - display_name=v.display_name, - language=v.language, - gender=v.gender.value, - ) - for v in voices - ] - - -@router.get("/podcasts/languages", response_model=LanguageOptions) -async def list_languages(): - """Languages the active TTS provider can offer the brief editor.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - offering = get_voice_catalog().offerable_languages(provider) - return LanguageOptions( - languages=offering.languages, - allows_custom=offering.allows_custom, - ) - - -@router.get("/podcasts/voices/{voice_id}/preview") -async def preview_voice( - voice_id: str, - user: User = Depends(current_active_user), -): - """A short audio sample of a voice, so users pick by sound.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - try: - voice = get_voice_catalog().get(voice_id) - except KeyError: - raise HTTPException(status_code=404, detail="Unknown voice") from None - if voice.provider is not provider: - raise HTTPException( - status_code=404, detail="Voice not offered by the active TTS provider" - ) - - data, content_type = await render_voice_preview(voice, get_text_to_speech()) - return Response(content=data, media_type=content_type) - - -@router.post("/podcasts", response_model=PodcastDetail, status_code=201) -async def create_podcast( - body: CreatePodcastRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE) - - service = PodcastService(session) - podcast = await service.create( - title=body.title, - search_space_id=body.search_space_id, - thread_id=body.thread_id, - ) - podcast.source_content = body.source_content - - spec = await propose_brief( - session, - search_space_id=body.search_space_id, - speaker_count=body.speaker_count, - min_seconds=body.min_seconds, - max_seconds=body.max_seconds, - focus=body.focus, - ) - await service.attach_brief(podcast, spec) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.get("/podcasts/{podcast_id}", response_model=PodcastDetail) -async def get_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) - return PodcastDetail.of(podcast) - - -@router.patch("/podcasts/{podcast_id}/spec", response_model=PodcastDetail) -async def update_spec( - podcast_id: int, - body: UpdateSpecRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).update_spec( - podcast, body.spec, body.expected_version - ) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/brief/approve", response_model=PodcastDetail) -async def approve_brief( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Approve the brief and start drafting the transcript.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).begin_drafting(podcast) - await session.commit() - draft_transcript_task.delay(podcast.id, podcast.search_space_id) - return PodcastDetail.of(podcast) - - -@router.post( - "/podcasts/{podcast_id}/transcript/regenerate", response_model=PodcastDetail -) -async def regenerate_transcript( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Reopen the brief gate for a fresh take; drafting waits for re-approval.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).regenerate(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/regenerate/revert", response_model=PodcastDetail) -async def revert_regeneration( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Back out of a regeneration and return to the finished episode.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).revert_regeneration(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/cancel", response_model=PodcastDetail) -async def cancel_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).cancel(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.delete("/podcasts/{podcast_id}", response_model=dict) -async def delete_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE) - await purge_audio(podcast) - await session.delete(podcast) - await session.commit() - return {"message": "Podcast deleted successfully"} - - -@router.get("/podcasts/{podcast_id}/stream") -async def stream_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) - - if podcast.storage_key: - # Verify first so a missing object is a 404, not a mid-stream crash. - if not await audio_exists(podcast): - raise HTTPException( - status_code=404, detail="Podcast audio is no longer available" - ) - return StreamingResponse( - open_audio_stream(podcast), - media_type="audio/mpeg", - headers={"Accept-Ranges": "bytes"}, - ) - - # Back-compat: rows rendered before the storage migration kept a local path. - if podcast.file_location and os.path.isfile(podcast.file_location): - path = podcast.file_location - - def iterfile(): - with open(path, mode="rb") as handle: - yield from handle - - return StreamingResponse( - iterfile(), - media_type="audio/mpeg", - headers={ - "Accept-Ranges": "bytes", - "Content-Disposition": f"inline; filename={Path(path).name}", - }, - ) - - # No audio: terminal states never will have any, otherwise it's in flight. - if PodcastStatus(podcast.status).is_terminal: - raise HTTPException(status_code=404, detail="Podcast audio not found") - raise HTTPException(status_code=409, detail="Podcast audio is not ready yet") - - -async def _require( - session: AsyncSession, - user: User, - search_space_id: int, - permission: Permission, -) -> None: - await check_permission( - session, - user, - search_space_id, - permission.value, - "You don't have permission for podcasts in this search space", - ) - - -async def _load( - session: AsyncSession, - user: User, - podcast_id: int, - permission: Permission, -) -> Podcast: - podcast = await PodcastRepository(session).get(podcast_id) - if podcast is None: - raise HTTPException(status_code=404, detail="Podcast not found") - await _require(session, user, podcast.search_space_id, permission) - return podcast - - -@asynccontextmanager -async def _lifecycle_errors() -> AsyncIterator[None]: - """Map service lifecycle errors onto HTTP responses.""" - try: - yield - except (SpecConflictError, InvalidTransitionError) as exc: - raise HTTPException(status_code=409, detail=str(exc)) from exc - except PreconditionFailedError as exc: - raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/podcasts/api/schemas.py b/surfsense_backend/app/podcasts/api/schemas.py deleted file mode 100644 index cb8559651..000000000 --- a/surfsense_backend/app/podcasts/api/schemas.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Request and response shapes for the podcast API. - -Read models surface the lifecycle state the frontend can't derive from Zero (the -deserialized brief and transcript); the action requests carry just what each -guarded transition needs. -""" - -from __future__ import annotations - -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - -from app.podcasts.duration_limits import ( - DEFAULT_MAX_SECONDS, - DEFAULT_MIN_SECONDS, - MAX_DURATION_SECONDS, - MIN_DURATION_SECONDS, -) -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.schemas import PodcastSpec, Transcript -from app.podcasts.service import has_stored_episode, read_spec, read_transcript - -# Defaults applied when a create request omits brief sizing; the brief gate lets -# the user adjust before any cost is incurred. -DEFAULT_SPEAKER_COUNT = 2 - - -class CreatePodcastRequest(BaseModel): - """Create a podcast and kick off brief proposal.""" - - title: str = Field(..., min_length=1, max_length=500) - search_space_id: int - source_content: str = Field(..., min_length=1) - thread_id: int | None = None - speaker_count: int = Field(default=DEFAULT_SPEAKER_COUNT, ge=1, le=6) - min_seconds: int = Field( - default=DEFAULT_MIN_SECONDS, - ge=MIN_DURATION_SECONDS, - le=MAX_DURATION_SECONDS, - ) - max_seconds: int = Field( - default=DEFAULT_MAX_SECONDS, - ge=MIN_DURATION_SECONDS, - le=MAX_DURATION_SECONDS, - ) - focus: str | None = Field(default=None, max_length=2000) - - -class UpdateSpecRequest(BaseModel): - """Replace the brief at the gate, guarded by the expected version.""" - - spec: PodcastSpec - expected_version: int = Field(..., ge=1) - - -class VoiceOption(BaseModel): - """One selectable voice surfaced to the brief editor.""" - - voice_id: str - display_name: str - language: str - gender: str - - -class LanguageOptions(BaseModel): - """The languages the brief editor may offer for the active provider. - - When ``allows_custom`` is true the list is a curated starting point and - the editor accepts any BCP-47 tag beyond it. - """ - - languages: list[str] - allows_custom: bool - - -class PodcastSummary(BaseModel): - """Lightweight list item.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - title: str - status: PodcastStatus - created_at: datetime - search_space_id: int - - -class PodcastDetail(BaseModel): - """Full podcast state for the detail view and action responses.""" - - id: int - title: str - status: PodcastStatus - spec_version: int - spec: PodcastSpec | None - transcript: Transcript | None - has_audio: bool - duration_seconds: int | None - error: str | None - created_at: datetime - search_space_id: int - thread_id: int | None - - @classmethod - def of(cls, podcast: Podcast) -> PodcastDetail: - return cls( - id=podcast.id, - title=podcast.title, - status=PodcastStatus(podcast.status), - spec_version=podcast.spec_version, - spec=read_spec(podcast), - transcript=read_transcript(podcast), - has_audio=has_stored_episode(podcast), - duration_seconds=podcast.duration_seconds, - error=podcast.error, - created_at=podcast.created_at, - search_space_id=podcast.search_space_id, - thread_id=podcast.thread_id, - ) diff --git a/surfsense_backend/app/podcasts/duration_limits.py b/surfsense_backend/app/podcasts/duration_limits.py deleted file mode 100644 index fc7d29890..000000000 --- a/surfsense_backend/app/podcasts/duration_limits.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Shared bounds and defaults for podcast target duration.""" - -MAX_DURATION_SECONDS = 24 * 60 * 60 -MIN_DURATION_SECONDS = 15 -DEFAULT_MIN_SECONDS = 20 -DEFAULT_MAX_SECONDS = 30 diff --git a/surfsense_backend/app/podcasts/generation/__init__.py b/surfsense_backend/app/podcasts/generation/__init__.py deleted file mode 100644 index a30b8f9af..000000000 --- a/surfsense_backend/app/podcasts/generation/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Generation: the controlled graphs that produce a brief and a transcript. - -``brief`` proposes a reviewable spec from deterministic defaults; ``transcript`` -is the LLM-driven step, drafting long-form dialogue outline-first. -""" - -from __future__ import annotations - -from .brief import BriefConfig, BriefState, build_brief_graph -from .transcript import TranscriptConfig, TranscriptState, build_transcript_graph - -__all__ = [ - "BriefConfig", - "BriefState", - "TranscriptConfig", - "TranscriptState", - "build_brief_graph", - "build_transcript_graph", -] diff --git a/surfsense_backend/app/podcasts/generation/brief/__init__.py b/surfsense_backend/app/podcasts/generation/brief/__init__.py deleted file mode 100644 index 5083c4708..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Brief planning: propose a reviewable spec from last-used preferences.""" - -from __future__ import annotations - -from .config import BriefConfig -from .graph import build_brief_graph -from .propose import propose_brief -from .state import BriefState - -__all__ = ["BriefConfig", "BriefState", "build_brief_graph", "propose_brief"] diff --git a/surfsense_backend/app/podcasts/generation/brief/config.py b/surfsense_backend/app/podcasts/generation/brief/config.py deleted file mode 100644 index 9b206bde4..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/config.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Configurable inputs for the brief-planning graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, field, fields - -from langchain_core.runnables import RunnableConfig - -from app.podcasts.duration_limits import ( - DEFAULT_MAX_SECONDS, - DEFAULT_MIN_SECONDS, -) - -# Sensible defaults for a fresh brief; the user adjusts the range at the gate. -DEFAULT_SPEAKER_COUNT = 2 - - -@dataclass(kw_only=True) -class BriefConfig: - """Signals used to propose a brief; everything here is non-LLM context.""" - - speaker_count: int = DEFAULT_SPEAKER_COUNT - min_seconds: int = DEFAULT_MIN_SECONDS - max_seconds: int = DEFAULT_MAX_SECONDS - focus: str | None = None - last_used_language: str | None = None - last_used_voices: list[str] = field(default_factory=list) - - @classmethod - def from_runnable_config(cls, config: RunnableConfig | None = None) -> BriefConfig: - configurable = (config.get("configurable") or {}) if config else {} - names = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in names}) diff --git a/surfsense_backend/app/podcasts/generation/brief/graph.py b/surfsense_backend/app/podcasts/generation/brief/graph.py deleted file mode 100644 index a643bdbb4..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/graph.py +++ /dev/null @@ -1,25 +0,0 @@ -"""The brief-planning graph: propose a reviewable spec from defaults.""" - -from __future__ import annotations - -from langgraph.graph import StateGraph - -from .config import BriefConfig -from .nodes import propose_spec -from .state import BriefState - - -def build_brief_graph(): - workflow = StateGraph(BriefState, config_schema=BriefConfig) - - workflow.add_node("propose_spec", propose_spec) - - workflow.add_edge("__start__", "propose_spec") - workflow.add_edge("propose_spec", "__end__") - - graph = workflow.compile() - graph.name = "Surfsense Podcast Brief" - return graph - - -graph = build_brief_graph() diff --git a/surfsense_backend/app/podcasts/generation/brief/nodes.py b/surfsense_backend/app/podcasts/generation/brief/nodes.py deleted file mode 100644 index de6a9717e..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/nodes.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Brief-planning node: propose a full spec from deterministic defaults. - -``propose_spec`` is pure resolution — it never spends tokens. It reuses the -user's last-used language/voices when available and otherwise falls back to -English, so the brief gate opens pre-filled and the common case needs no edits. -""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.runnables import RunnableConfig - -from app.config import config as app_config -from app.podcasts.resolution import ( - DEFAULT_LANGUAGE, - LanguageContext, - resolve_language, - resolve_voices, -) -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - normalize_language_tag, -) -from app.podcasts.voices import ( - TtsProvider, - VoiceCatalog, - get_voice_catalog, - provider_from_service, -) - -from .config import BriefConfig -from .state import BriefState - -# Default role per speaker slot; extra speakers beyond the list fall back to guest. -_ROLE_BY_SLOT = ( - SpeakerRole.HOST, - SpeakerRole.GUEST, - SpeakerRole.EXPERT, - SpeakerRole.COHOST, - SpeakerRole.NARRATOR, -) - - -def propose_spec(state: BriefState, config: RunnableConfig) -> dict[str, Any]: - """Build a complete :class:`PodcastSpec` from the resolved defaults.""" - brief = BriefConfig.from_runnable_config(config) - provider = _active_provider() - catalog = get_voice_catalog() - - language = _supported_language( - last_used=brief.last_used_language, - provider=provider, - catalog=catalog, - ) - voices = resolve_voices( - catalog=catalog, - provider=provider, - language=language, - speaker_count=brief.speaker_count, - preferred=brief.last_used_voices, - ) - - speakers = [ - SpeakerSpec( - slot=slot, - name=_default_name(slot), - role=_role_for(slot), - voice_id=voice.voice_id, - ) - for slot, voice in enumerate(voices) - ] - spec = PodcastSpec( - language=language, - style=PodcastStyle.CONVERSATIONAL, - speakers=speakers, - duration=DurationTarget( - min_seconds=brief.min_seconds, max_seconds=brief.max_seconds - ), - focus=brief.focus, - ) - return {"spec": spec} - - -def _active_provider() -> TtsProvider: - service = app_config.TTS_SERVICE - if not service: - raise ValueError("TTS_SERVICE is not configured") - return provider_from_service(service) - - -def _supported_language( - *, - last_used: str | None, - provider: TtsProvider, - catalog: VoiceCatalog, -) -> str: - raw = resolve_language(LanguageContext(last_used=last_used)) - try: - language = normalize_language_tag(raw) - except ValueError: - language = DEFAULT_LANGUAGE - if not catalog.supports_language(provider, language): - return DEFAULT_LANGUAGE - return language - - -def _role_for(slot: int) -> SpeakerRole: - return _ROLE_BY_SLOT[slot] if slot < len(_ROLE_BY_SLOT) else SpeakerRole.GUEST - - -def _default_name(slot: int) -> str: - role = _role_for(slot) - label = role.value.replace("cohost", "co-host").title() - return label if slot < len(_ROLE_BY_SLOT) else f"{label} {slot}" diff --git a/surfsense_backend/app/podcasts/generation/brief/propose.py b/surfsense_backend/app/podcasts/generation/brief/propose.py deleted file mode 100644 index 09d74840e..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/propose.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Propose a podcast's initial brief spec.""" - -from __future__ import annotations - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.duration_limits import DEFAULT_MAX_SECONDS, DEFAULT_MIN_SECONDS -from app.podcasts.persistence import PodcastRepository -from app.podcasts.schemas import PodcastSpec -from app.podcasts.service import preferences_from - -from .config import DEFAULT_SPEAKER_COUNT -from .graph import graph as brief_graph -from .state import BriefState - - -async def propose_brief( - session: AsyncSession, - *, - search_space_id: int, - speaker_count: int = DEFAULT_SPEAKER_COUNT, - min_seconds: int = DEFAULT_MIN_SECONDS, - max_seconds: int = DEFAULT_MAX_SECONDS, - focus: str | None = None, -) -> PodcastSpec: - """Reuse the last-used language and voices, else English; return the spec.""" - last_language, last_voices = preferences_from( - await PodcastRepository(session).latest_with_spec(search_space_id) - ) - config = { - "configurable": { - "speaker_count": speaker_count, - "min_seconds": min_seconds, - "max_seconds": max_seconds, - "focus": focus, - "last_used_language": last_language, - "last_used_voices": last_voices, - } - } - result = await brief_graph.ainvoke(BriefState(), config=config) - return result["spec"] diff --git a/surfsense_backend/app/podcasts/generation/brief/state.py b/surfsense_backend/app/podcasts/generation/brief/state.py deleted file mode 100644 index 418fb6fa9..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/state.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Mutable state threaded through the brief-planning graph.""" - -from __future__ import annotations - -from dataclasses import dataclass - -from app.podcasts.schemas import PodcastSpec - - -@dataclass -class BriefState: - """The proposed spec the graph produces; inputs arrive via the config.""" - - spec: PodcastSpec | None = None diff --git a/surfsense_backend/app/podcasts/generation/prompts/__init__.py b/surfsense_backend/app/podcasts/generation/prompts/__init__.py deleted file mode 100644 index 041dd4e6d..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Prompt builders for the generation graphs.""" - -from __future__ import annotations - -from .draft_segment import draft_segment_prompt -from .plan_outline import plan_outline_prompt -from .speakers import render_speaker_roster - -__all__ = [ - "draft_segment_prompt", - "plan_outline_prompt", - "render_speaker_roster", -] diff --git a/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py b/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py deleted file mode 100644 index c81dfa385..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Prompt for drafting one outline segment into dialogue turns. - -Each segment is drafted on its own so long episodes stay coherent and within -context limits. A short recap of the preceding dialogue is passed in so the new -segment continues naturally instead of restarting. The model must write in the -episode language and attribute every line to a real speaker slot. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from app.podcasts.schemas import PodcastSpec - -from .speakers import render_speaker_roster - -if TYPE_CHECKING: - from app.podcasts.generation.transcript.planning import OutlineSegment - - -def draft_segment_prompt( - *, - spec: PodcastSpec, - segment: OutlineSegment, - position: int, - total: int, - recap: str | None, -) -> str: - talking_points = "\n".join(f"- {point}" for point in segment.talking_points) - recap_block = ( - f"\nRecap of the conversation so far (continue from here, do not repeat " - f"it):\n{recap}\n" - if recap - else "\nThis is the opening segment; begin the conversation naturally.\n" - ) - return f"""\ -You are scripting natural, engaging podcast dialogue for segment {position} of \ -{total}. - -Write entirely in {spec.language}. The format is {spec.style.value}. -Speakers — attribute every line using these exact slot numbers: -{render_speaker_roster(spec)} -{recap_block} -This segment is "{segment.title}". Cover these points using only facts grounded \ -in the provided source content: -{talking_points} - -Aim for about {segment.target_words} words of dialogue. Keep turns conversational \ -and varied; speakers should react to each other rather than deliver monologues. \ -Do not add greetings or sign-offs unless this is the first or last segment. - -Respond with strict JSON and nothing else: -{{"turns": [{{"speaker": <slot>, "text": "..."}}]}} -""" diff --git a/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py b/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py deleted file mode 100644 index 1b227c2ff..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Prompt for planning a long-form podcast outline before drafting dialogue. - -Outlining first is what makes long-form reliable: a single LLM call cannot hold -a coherent one- to two-hour script, but it can plan segments that are then -drafted independently against a shared plan. The prompt is told the target -length so the number and size of segments scale with the requested duration. -""" - -from __future__ import annotations - -from app.podcasts.schemas import PodcastSpec - -from .speakers import render_speaker_roster - - -def plan_outline_prompt( - *, - spec: PodcastSpec, - target_words: int, - suggested_segments: int, - focus: str | None, -) -> str: - focus_block = ( - f"\nThe user asked the episode to focus on:\n{focus}\n" if focus else "" - ) - return f"""\ -You are a podcast showrunner planning the structure of an episode before any \ -dialogue is written. - -The episode language is {spec.language}. The format is {spec.style.value}. -Speakers (refer to them by these slots later): -{render_speaker_roster(spec)} -{focus_block} -Plan an outline that, when fully drafted, reaches roughly {target_words} words \ -of spoken dialogue (about {suggested_segments} segments). Each segment is one \ -coherent beat of the conversation: an opening, distinct topic areas grounded in \ -the source content, and a closing. - -For each segment provide: -- title: a short label for the beat -- talking_points: 2-5 concrete points to cover, drawn from the source content -- target_words: how many words of dialogue this segment should run (the sum \ -across segments should approximate {target_words}) - -Respond with strict JSON and nothing else: -{{"segments": [{{"title": "...", "talking_points": ["..."], "target_words": 0}}]}} -""" diff --git a/surfsense_backend/app/podcasts/generation/prompts/speakers.py b/surfsense_backend/app/podcasts/generation/prompts/speakers.py deleted file mode 100644 index 9df4138df..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/speakers.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Render a spec's speaker roster for prompts. - -The drafting prompts must reference speakers by the exact ``slot`` the renderer -expects, so this is the single place that formats that roster — keeping the -slot contract identical across every prompt that mentions speakers. -""" - -from __future__ import annotations - -from app.podcasts.schemas import PodcastSpec - - -def render_speaker_roster(spec: PodcastSpec) -> str: - lines = [ - f"- slot {speaker.slot} — {speaker.name} (role: {speaker.role.value})" - for speaker in spec.speakers - ] - return "\n".join(lines) diff --git a/surfsense_backend/app/podcasts/generation/structured.py b/surfsense_backend/app/podcasts/generation/structured.py deleted file mode 100644 index 08132e776..000000000 --- a/surfsense_backend/app/podcasts/generation/structured.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Parse a model's reply into a Pydantic shape, tolerating chatty output. - -Agent LLMs return JSON wrapped in prose, markdown fences, or reasoning blocks, -so a plain ``model_validate_json`` is unreliable. Centralising the tolerant -parse here keeps every generation node validating replies the same way. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar - -from pydantic import BaseModel, ValidationError - -from app.utils.content_utils import extract_text_content, strip_markdown_fences - -if TYPE_CHECKING: - from langchain_core.messages import BaseMessage - -T = TypeVar("T", bound=BaseModel) - - -class StructuredOutputError(RuntimeError): - """The model reply could not be parsed into the expected shape.""" - - -async def invoke_json[T: BaseModel]( - llm, messages: list[BaseMessage], model: type[T] -) -> T: - """Invoke ``llm`` and validate its reply as ``model``.""" - response = await llm.ainvoke(messages) - content = strip_markdown_fences(extract_text_content(response.content)) - - try: - return model.model_validate_json(content) - except (ValidationError, ValueError): - pass - - start = content.find("{") - end = content.rfind("}") + 1 - if 0 <= start < end: - try: - return model.model_validate_json(content[start:end]) - except (ValidationError, ValueError) as exc: - raise StructuredOutputError( - f"could not parse {model.__name__} from model reply" - ) from exc - - raise StructuredOutputError( - f"no JSON object found for {model.__name__} in model reply" - ) diff --git a/surfsense_backend/app/podcasts/generation/transcript/__init__.py b/surfsense_backend/app/podcasts/generation/transcript/__init__.py deleted file mode 100644 index 5c8f23cd7..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Transcript drafting: outline-first, long-form dialogue generation.""" - -from __future__ import annotations - -from .config import TranscriptConfig -from .graph import build_transcript_graph -from .planning import Outline, OutlineSegment, SegmentDraft -from .state import TranscriptState - -__all__ = [ - "Outline", - "OutlineSegment", - "SegmentDraft", - "TranscriptConfig", - "TranscriptState", - "build_transcript_graph", -] diff --git a/surfsense_backend/app/podcasts/generation/transcript/config.py b/surfsense_backend/app/podcasts/generation/transcript/config.py deleted file mode 100644 index f627fc166..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/config.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Configurable inputs for the transcript-drafting graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, fields - -from langchain_core.runnables import RunnableConfig - -from app.podcasts.schemas import PodcastSpec - - -@dataclass(kw_only=True) -class TranscriptConfig: - """The approved spec and user focus that drive drafting.""" - - search_space_id: int - spec: PodcastSpec - focus: str | None = None - - @classmethod - def from_runnable_config( - cls, config: RunnableConfig | None = None - ) -> TranscriptConfig: - configurable = (config.get("configurable") or {}) if config else {} - names = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in names}) diff --git a/surfsense_backend/app/podcasts/generation/transcript/graph.py b/surfsense_backend/app/podcasts/generation/transcript/graph.py deleted file mode 100644 index 2f97db50f..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/graph.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The transcript-drafting graph: outline, draft segments, finalize.""" - -from __future__ import annotations - -from langgraph.graph import StateGraph - -from .config import TranscriptConfig -from .nodes import draft_segments, finalize, plan_outline -from .state import TranscriptState - - -def build_transcript_graph(): - workflow = StateGraph(TranscriptState, config_schema=TranscriptConfig) - - workflow.add_node("plan_outline", plan_outline) - workflow.add_node("draft_segments", draft_segments) - workflow.add_node("finalize", finalize) - - workflow.add_edge("__start__", "plan_outline") - workflow.add_edge("plan_outline", "draft_segments") - workflow.add_edge("draft_segments", "finalize") - workflow.add_edge("finalize", "__end__") - - graph = workflow.compile() - graph.name = "Surfsense Podcast Transcript" - return graph - - -graph = build_transcript_graph() diff --git a/surfsense_backend/app/podcasts/generation/transcript/nodes.py b/surfsense_backend/app/podcasts/generation/transcript/nodes.py deleted file mode 100644 index 7b472348d..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/nodes.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Transcript-drafting nodes: plan an outline, draft each beat, then assemble. - -Long-form is produced beat-by-beat: a single call plans the structure, then each -segment is drafted on its own with a recap of what came before so the script -stays coherent without holding the whole episode in one context window. -""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import RunnableConfig - -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn -from app.services.llm_service import get_agent_llm - -from ..prompts import draft_segment_prompt, plan_outline_prompt -from ..structured import invoke_json -from .config import TranscriptConfig -from .planning import Outline, SegmentDraft -from .state import TranscriptState - -# Average speaking rate; converts target minutes to a target word count. -_WORDS_PER_MINUTE = 150 -# Rough words per outline segment, used to suggest how many segments to plan. -_WORDS_PER_SEGMENT = 250 -# Cap on source text sent per LLM call to bound tokens on large sources. -_SOURCE_BUDGET_CHARS = 12000 -# How much prior dialogue to recap into each segment for continuity. -_RECAP_CHARS = 800 - - -async def plan_outline( - state: TranscriptState, config: RunnableConfig -) -> dict[str, Any]: - """Plan the segment structure sized to the spec's target duration.""" - tc = TranscriptConfig.from_runnable_config(config) - llm = await _require_llm(state, tc) - - target_words = round(tc.spec.duration.midpoint_seconds * _WORDS_PER_MINUTE / 60) - suggested_segments = max(1, round(target_words / _WORDS_PER_SEGMENT)) - - messages = [ - SystemMessage( - content=plan_outline_prompt( - spec=tc.spec, - target_words=target_words, - suggested_segments=suggested_segments, - focus=tc.focus, - ) - ), - HumanMessage(content=_source_block(state.source_content)), - ] - outline = await invoke_json(llm, messages, Outline) - return {"outline": outline} - - -async def draft_segments( - state: TranscriptState, config: RunnableConfig -) -> dict[str, Any]: - """Draft each outline segment in order, carrying a running recap.""" - tc = TranscriptConfig.from_runnable_config(config) - llm = await _require_llm(state, tc) - outline = state.outline - if outline is None: - raise RuntimeError("draft_segments requires an outline") - - source_block = _source_block(state.source_content) - turns: list[TranscriptTurn] = [] - total = len(outline.segments) - - for index, segment in enumerate(outline.segments): - messages = [ - SystemMessage( - content=draft_segment_prompt( - spec=tc.spec, - segment=segment, - position=index + 1, - total=total, - recap=_recap(turns, tc.spec), - ) - ), - HumanMessage(content=source_block), - ] - draft = await invoke_json(llm, messages, SegmentDraft) - turns.extend(_valid_turns(draft, tc.spec)) - - return {"drafted_turns": turns} - - -def finalize(state: TranscriptState, config: RunnableConfig) -> dict[str, Any]: - """Assemble drafted turns into a validated transcript.""" - if not state.drafted_turns: - raise RuntimeError("drafting produced no usable dialogue") - return {"transcript": Transcript(turns=state.drafted_turns)} - - -async def _require_llm(state: TranscriptState, tc: TranscriptConfig): - llm = await get_agent_llm(state.db_session, tc.search_space_id) - if llm is None: - raise RuntimeError( - f"no agent LLM configured for search space {tc.search_space_id}" - ) - return llm - - -def _source_block(source_content: str) -> str: - sample = (source_content or "")[:_SOURCE_BUDGET_CHARS] - return f"<source_content>{sample}</source_content>" - - -def _valid_turns(draft: SegmentDraft, spec: PodcastSpec) -> list[TranscriptTurn]: - # Drop any turn the model attributed to a slot the spec doesn't define, so a - # stray attribution can't break rendering downstream. - valid_slots = {speaker.slot for speaker in spec.speakers} - return [turn for turn in draft.turns if turn.speaker in valid_slots] - - -def _recap(turns: list[TranscriptTurn], spec: PodcastSpec) -> str | None: - if not turns: - return None - names = {speaker.slot: speaker.name for speaker in spec.speakers} - rendered = "\n".join( - f"{names.get(turn.speaker, turn.speaker)}: {turn.text}" for turn in turns - ) - return rendered[-_RECAP_CHARS:] diff --git a/surfsense_backend/app/podcasts/generation/transcript/planning.py b/surfsense_backend/app/podcasts/generation/transcript/planning.py deleted file mode 100644 index 3f6aeac9b..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/planning.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Internal shapes the transcript graph passes between its nodes. - -These are generation-time artifacts (the outline and per-segment drafts), not -persisted or API-facing. Segment drafts reuse :class:`TranscriptTurn` so the -speaker-slot contract and turn validation are identical to the final transcript. -""" - -from __future__ import annotations - -from pydantic import BaseModel, Field - -from app.podcasts.schemas import TranscriptTurn - - -class OutlineSegment(BaseModel): - """One planned beat of the conversation, drafted independently.""" - - title: str = Field(..., min_length=1) - talking_points: list[str] = Field(default_factory=list) - target_words: int = Field(..., ge=1) - - -class Outline(BaseModel): - """The full plan: ordered segments sized to the target duration.""" - - segments: list[OutlineSegment] = Field(..., min_length=1) - - -class SegmentDraft(BaseModel): - """The dialogue a single segment produced.""" - - turns: list[TranscriptTurn] = Field(default_factory=list) diff --git a/surfsense_backend/app/podcasts/generation/transcript/state.py b/surfsense_backend/app/podcasts/generation/transcript/state.py deleted file mode 100644 index f11337471..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/state.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Mutable state threaded through the transcript-drafting graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.schemas import Transcript, TranscriptTurn - -from .planning import Outline - - -@dataclass -class TranscriptState: - """Source content plus the intermediate and final drafting artifacts.""" - - db_session: AsyncSession - source_content: str - outline: Outline | None = None - drafted_turns: list[TranscriptTurn] = field(default_factory=list) - transcript: Transcript | None = None diff --git a/surfsense_backend/app/podcasts/persistence/__init__.py b/surfsense_backend/app/podcasts/persistence/__init__.py deleted file mode 100644 index 2166d5d9d..000000000 --- a/surfsense_backend/app/podcasts/persistence/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Models, enums, and data access for the podcasts table.""" - -from __future__ import annotations - -from .enums import PodcastStatus -from .models import Podcast -from .repository import PodcastRepository - -__all__ = ["Podcast", "PodcastRepository", "PodcastStatus"] diff --git a/surfsense_backend/app/podcasts/persistence/enums/__init__.py b/surfsense_backend/app/podcasts/persistence/enums/__init__.py deleted file mode 100644 index f0527fd78..000000000 --- a/surfsense_backend/app/podcasts/persistence/enums/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Enums for the podcasts table.""" - -from __future__ import annotations - -from .podcast_status import PodcastStatus - -__all__ = ["PodcastStatus"] diff --git a/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py b/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py deleted file mode 100644 index 28f29afb5..000000000 --- a/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Podcast generation lifecycle. - -The status drives a guarded state machine. A podcast is proposed (``PENDING``), -gets a reviewable brief (``AWAITING_BRIEF``), is drafted into a transcript -(``DRAFTING``), then rendered to audio (``RENDERING`` → ``READY``). ``FAILED`` -and ``CANCELLED`` are terminal; a ``READY`` episode can be sent back to the -brief gate for regeneration, and an in-flight regeneration can be reverted to -``READY`` while the previous audio still exists. ``AWAITING_REVIEW`` is -retained for legacy rows but -never entered anymore — the brief is the only approval gate. The Python enum is -kept in lockstep with the ``podcast_status`` Postgres type via its paired -migration. -""" - -from __future__ import annotations - -from enum import StrEnum - - -class PodcastStatus(StrEnum): - PENDING = "pending" - AWAITING_BRIEF = "awaiting_brief" - DRAFTING = "drafting" - AWAITING_REVIEW = "awaiting_review" - RENDERING = "rendering" - READY = "ready" - FAILED = "failed" - CANCELLED = "cancelled" - - @property - def is_terminal(self) -> bool: - """Whether no further transition is possible from this state.""" - return self in _TERMINAL - - @property - def is_gate(self) -> bool: - """Whether this state waits on user input before proceeding.""" - return self in _GATES - - -_TERMINAL = frozenset({PodcastStatus.FAILED, PodcastStatus.CANCELLED}) -_GATES = frozenset({PodcastStatus.AWAITING_BRIEF, PodcastStatus.AWAITING_REVIEW}) diff --git a/surfsense_backend/app/podcasts/persistence/models.py b/surfsense_backend/app/podcasts/persistence/models.py deleted file mode 100644 index 6e40a8040..000000000 --- a/surfsense_backend/app/podcasts/persistence/models.py +++ /dev/null @@ -1,82 +0,0 @@ -"""``podcasts`` table: a generated podcast, its brief, transcript, and state.""" - -from __future__ import annotations - -from sqlalchemy import ( - Column, - Enum as SQLAlchemyEnum, - ForeignKey, - Integer, - String, - Text, -) -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import relationship - -from app.db import BaseModel, TimestampMixin - -from .enums import PodcastStatus - - -class Podcast(BaseModel, TimestampMixin): - """A podcast across its whole lifecycle: brief, transcript, audio, status. - - ``spec`` (the reviewable brief) and ``podcast_transcript`` are JSONB so the - flexible Pydantic shapes can evolve without migrations. ``spec_version`` - backs optimistic concurrency on brief edits. Rendered audio lives in the - object store, addressed by ``storage_backend`` + ``storage_key`` rather than - a raw path. - """ - - __tablename__ = "podcasts" - - title = Column(String(500), nullable=False) - - status = Column( - SQLAlchemyEnum( - PodcastStatus, - name="podcast_status", - create_type=False, - values_callable=lambda x: [e.value for e in x], - ), - nullable=False, - default=PodcastStatus.PENDING, - server_default=PodcastStatus.PENDING.value, - index=True, - ) - - # The source material the episode is generated from. Persisted because - # drafting happens after the brief gate, long after creation. - source_content = Column(Text, nullable=True) - - # The reviewable brief (PodcastSpec); null until the brief gate is reached. - spec = Column(JSONB, nullable=True) - # Bumped on every spec edit; guards concurrent edits at the brief gate. - spec_version = Column(Integer, nullable=False, default=1, server_default="1") - - # The drafted dialogue (Transcript); null until drafting completes. - podcast_transcript = Column(JSONB, nullable=True) - - # Where the rendered audio lives in the object store; null until READY. - storage_backend = Column(String(32), nullable=True) - storage_key = Column(Text, nullable=True) - duration_seconds = Column(Integer, nullable=True) - - # Human-readable reason when status is FAILED. - error = Column(Text, nullable=True) - - # Legacy local audio path; retained for back-compat until cutover. - file_location = Column(Text, nullable=True) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="podcasts") - - thread_id = Column( - Integer, - ForeignKey("new_chat_threads.id", ondelete="SET NULL"), - nullable=True, - index=True, - ) - thread = relationship("NewChatThread") diff --git a/surfsense_backend/app/podcasts/persistence/repository.py b/surfsense_backend/app/podcasts/persistence/repository.py deleted file mode 100644 index 04eae9ce1..000000000 --- a/surfsense_backend/app/podcasts/persistence/repository.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Data access for the ``podcasts`` table. - -A thin async repository so the service and tasks never write raw queries. It -only loads and persists rows; lifecycle rules and (de)serialization live in the -service. -""" - -from __future__ import annotations - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from .models import Podcast - - -class PodcastRepository: - """Loads and stores :class:`Podcast` rows for one session.""" - - def __init__(self, session: AsyncSession) -> None: - self._session = session - - async def get(self, podcast_id: int) -> Podcast | None: - return await self._session.get(Podcast, podcast_id) - - async def add(self, podcast: Podcast) -> Podcast: - """Persist a new row and assign its primary key.""" - self._session.add(podcast) - await self._session.flush() - return podcast - - async def latest_with_spec(self, search_space_id: int) -> Podcast | None: - """Most recent podcast in the space that has a stored brief. - - Used to seed language/voice defaults for a new podcast from what the - user chose last. - """ - result = await self._session.execute( - select(Podcast) - .where( - Podcast.search_space_id == search_space_id, - Podcast.spec.is_not(None), - ) - .order_by(Podcast.created_at.desc()) - .limit(1) - ) - return result.scalars().first() diff --git a/surfsense_backend/app/podcasts/rendering/__init__.py b/surfsense_backend/app/podcasts/rendering/__init__.py deleted file mode 100644 index 9fb50a2e1..000000000 --- a/surfsense_backend/app/podcasts/rendering/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Rendering: synthesise and merge an approved transcript into audio. - -The :class:`PodcastRenderer` is the public entry point; the segment cache and -FFmpeg merge are implementation details it owns. -""" - -from __future__ import annotations - -from .errors import RenderError -from .renderer import PodcastRenderer, RenderedPodcast - -__all__ = ["PodcastRenderer", "RenderError", "RenderedPodcast"] diff --git a/surfsense_backend/app/podcasts/rendering/cache.py b/surfsense_backend/app/podcasts/rendering/cache.py deleted file mode 100644 index 32d9f0c21..000000000 --- a/surfsense_backend/app/podcasts/rendering/cache.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Content-addressed cache for synthesised segments. - -Each segment's audio is keyed by everything that determines its bytes (voice, -language, speed, text). Keeping the cache in a stable per-podcast directory -makes re-renders cheap: changing one speaker's voice only misses that speaker's -turns, and a worker restart mid-render resumes from whatever was already -written. The key intentionally excludes the segment's position so identical -lines (e.g. repeated "Right.") synthesise once. -""" - -from __future__ import annotations - -import hashlib -import json -from pathlib import Path - -from app.podcasts.tts import SynthesisRequest - - -class SegmentCache: - """On-disk store of segment audio, addressed by request content hash.""" - - def __init__(self, root: Path) -> None: - self._root = root - self._root.mkdir(parents=True, exist_ok=True) - - def key(self, request: SynthesisRequest) -> str: - """A stable hash of the inputs that determine the synthesised bytes.""" - material = json.dumps( - { - "voice": request.voice, - "language": request.language, - "speed": request.speed, - "text": request.text, - }, - sort_keys=True, - ensure_ascii=True, - ) - return hashlib.sha256(material.encode("utf-8")).hexdigest() - - def path(self, key: str, container: str) -> Path: - return self._root / f"{key}.{container}" - - def get(self, key: str, container: str) -> Path | None: - """Return the cached segment path, or ``None`` on a miss.""" - path = self.path(key, container) - return path if path.exists() else None - - def put(self, key: str, container: str, data: bytes) -> Path: - """Write ``data`` for ``key`` and return its path.""" - path = self.path(key, container) - path.write_bytes(data) - return path diff --git a/surfsense_backend/app/podcasts/rendering/errors.py b/surfsense_backend/app/podcasts/rendering/errors.py deleted file mode 100644 index 7192890c6..000000000 --- a/surfsense_backend/app/podcasts/rendering/errors.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Failures raised while rendering a transcript to audio.""" - -from __future__ import annotations - - -class RenderError(RuntimeError): - """Rendering could not produce a final audio file. - - Wraps both per-segment synthesis failures and the merge step so the render - task sees one failure type regardless of where it originated. - """ diff --git a/surfsense_backend/app/podcasts/rendering/merge.py b/surfsense_backend/app/podcasts/rendering/merge.py deleted file mode 100644 index 223295349..000000000 --- a/surfsense_backend/app/podcasts/rendering/merge.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Concatenate ordered segment files into a single MP3. - -Uses FFmpeg's concat *demuxer* (a list file of inputs) rather than a -``filter_complex`` graph. The demuxer takes one ``-i`` no matter how many -segments there are, so an hour-long episode with thousands of turns never hits -command-line length limits. Output is always re-encoded to MP3 for a uniform -artifact regardless of the source container (Kokoro WAV or hosted MP3). -""" - -from __future__ import annotations - -from pathlib import Path - -from ffmpeg.asyncio import FFmpeg - -from .errors import RenderError - - -async def concat_to_mp3(segment_paths: list[Path], output_path: Path) -> None: - """Merge ``segment_paths`` in order into ``output_path`` as MP3.""" - if not segment_paths: - raise RenderError("cannot merge an empty list of segments") - - list_file = output_path.with_name(f"{output_path.stem}.concat.txt") - list_file.write_text(_concat_list(segment_paths), encoding="utf-8") - - try: - ffmpeg = ( - FFmpeg() - .option("y") - .input(str(list_file), f="concat", safe=0) - .output(str(output_path), {"c:a": "libmp3lame"}) - ) - await ffmpeg.execute() - except Exception as exc: - raise RenderError(f"audio merge failed: {exc}") from exc - finally: - list_file.unlink(missing_ok=True) - - -def _concat_list(segment_paths: list[Path]) -> str: - # The concat demuxer reads `file '<path>'` lines; single quotes in a path - # are escaped per its quoting rules ('\''). - lines = [] - for path in segment_paths: - escaped = str(path.resolve()).replace("'", "'\\''") - lines.append(f"file '{escaped}'") - return "\n".join(lines) + "\n" diff --git a/surfsense_backend/app/podcasts/rendering/renderer.py b/surfsense_backend/app/podcasts/rendering/renderer.py deleted file mode 100644 index 44071c060..000000000 --- a/surfsense_backend/app/podcasts/rendering/renderer.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Render an approved transcript into a single podcast audio file. - -The renderer is the only place that turns dialogue into sound. It maps each -turn to its speaker's voice, synthesises segments concurrently (capped, served -from the segment cache when possible, and coalesced so identical lines render -once), then merges them in order. It takes a settled spec + transcript and -returns bytes; persistence and lifecycle transitions belong to the service. -""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -from pathlib import Path - -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn -from app.podcasts.tts import SynthesisRequest, TextToSpeech, TextToSpeechError -from app.podcasts.voices import VoiceCatalog - -from .cache import SegmentCache -from .errors import RenderError -from .merge import concat_to_mp3 - -# Bounds how many segments synthesise at once. Protects hosted-provider rate -# limits and avoids thrashing the local Kokoro pipeline; the renderer is I/O- or -# model-bound per segment, so a small pool already saturates throughput. -DEFAULT_MAX_CONCURRENCY = 4 - -_MERGED_FILENAME = "podcast.mp3" - - -@dataclass(frozen=True, slots=True) -class RenderedPodcast: - """The finished episode: encoded bytes plus their container.""" - - data: bytes - container: str - - -class PodcastRenderer: - """Synthesises and merges a transcript using one TTS provider.""" - - def __init__( - self, - *, - tts: TextToSpeech, - catalog: VoiceCatalog, - max_concurrency: int = DEFAULT_MAX_CONCURRENCY, - ) -> None: - self._tts = tts - self._catalog = catalog - self._max_concurrency = max_concurrency - - async def render( - self, - *, - spec: PodcastSpec, - transcript: Transcript, - workdir: Path, - ) -> RenderedPodcast: - """Produce the merged MP3 for ``transcript`` under ``spec``. - - ``workdir`` holds the segment cache and merge output; reusing the same - directory across renders is what makes voice edits cheap. - """ - cache = SegmentCache(workdir / "segments") - requests = [self._request_for(spec, turn) for turn in transcript.turns] - - # Concurrency primitives are created per render so each call is bound to - # the event loop running it (Celery tasks may use a fresh loop). - synthesizer = _SegmentSynthesizer(self._tts, cache, self._max_concurrency) - segment_paths = await asyncio.gather( - *(synthesizer.segment(request) for request in requests) - ) - - output_path = workdir / _MERGED_FILENAME - await concat_to_mp3(list(segment_paths), output_path) - return RenderedPodcast(data=output_path.read_bytes(), container="mp3") - - def _request_for(self, spec: PodcastSpec, turn: TranscriptTurn) -> SynthesisRequest: - try: - speaker = spec.speaker_for(turn.speaker) - except KeyError as exc: - raise RenderError( - f"transcript references unknown speaker slot {turn.speaker}" - ) from exc - try: - voice = self._catalog.get(speaker.voice_id) - except KeyError as exc: - raise RenderError(f"unknown voice {speaker.voice_id!r}") from exc - return SynthesisRequest( - text=turn.text, voice=voice.native_ref, language=spec.language - ) - - -class _SegmentSynthesizer: - """Per-render synthesis coordinator: caps concurrency and dedupes work. - - Beyond the on-disk cache (which serves cross-render reuse), this coalesces - identical segments that race within one render so the same line is voiced - once even when several turns request it simultaneously. - """ - - def __init__( - self, tts: TextToSpeech, cache: SegmentCache, max_concurrency: int - ) -> None: - self._tts = tts - self._cache = cache - self._container = tts.container - self._semaphore = asyncio.Semaphore(max_concurrency) - self._inflight: dict[str, asyncio.Future[Path]] = {} - self._inflight_lock = asyncio.Lock() - - async def segment(self, request: SynthesisRequest) -> Path: - key = self._cache.key(request) - cached = self._cache.get(key, self._container) - if cached is not None: - return cached - - async with self._inflight_lock: - future = self._inflight.get(key) - owner = future is None - if owner: - future = asyncio.get_event_loop().create_future() - self._inflight[key] = future - - # The owner runs the work and publishes the outcome on the shared future; - # every caller (owner included) reads it back via ``await future`` so the - # result is retrieved exactly once-or-more and never left dangling. - if owner: - try: - path = await self._synthesize(request, key) - except BaseException as exc: - future.set_exception(exc) - else: - future.set_result(path) - finally: - await self._forget(key) - - return await future - - async def _synthesize(self, request: SynthesisRequest, key: str) -> Path: - async with self._semaphore: - cached = self._cache.get(key, self._container) - if cached is not None: - return cached - try: - audio = await self._tts.synthesize(request) - except TextToSpeechError as exc: - raise RenderError(f"segment synthesis failed: {exc}") from exc - return self._cache.put(key, audio.container, audio.data) - - async def _forget(self, key: str) -> None: - async with self._inflight_lock: - self._inflight.pop(key, None) diff --git a/surfsense_backend/app/podcasts/resolution/__init__.py b/surfsense_backend/app/podcasts/resolution/__init__.py deleted file mode 100644 index 19a7edfb3..000000000 --- a/surfsense_backend/app/podcasts/resolution/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Resolution: deterministic default chains for a fresh brief. - -Turns the user's last-used preferences into concrete language and voice -defaults, so the brief gate opens pre-filled and most users approve without -editing. -""" - -from __future__ import annotations - -from .language import ( - DEFAULT_LANGUAGE, - DEFAULT_LANGUAGE_CHAIN, - LanguageContext, - LanguageResolver, - resolve_language, -) -from .voices import VoiceResolutionError, resolve_voices - -__all__ = [ - "DEFAULT_LANGUAGE", - "DEFAULT_LANGUAGE_CHAIN", - "LanguageContext", - "LanguageResolver", - "VoiceResolutionError", - "resolve_language", - "resolve_voices", -] diff --git a/surfsense_backend/app/podcasts/resolution/language.py b/surfsense_backend/app/podcasts/resolution/language.py deleted file mode 100644 index 336d9036b..000000000 --- a/surfsense_backend/app/podcasts/resolution/language.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Resolve the brief's language without spending tokens at the gate. - -The chain mirrors the agreed policy: reuse the language the user last chose, and -otherwise default to English (which the user can still override in the brief). We -deliberately never guess the language from the source content — proposing a -language the user did not ask for is worse than a predictable default. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -# What a brand-new user with no signal gets, and what every chain ends on. -DEFAULT_LANGUAGE = "en" - - -@dataclass(frozen=True, slots=True) -class LanguageContext: - """Signals available when proposing a language for a fresh podcast.""" - - last_used: str | None = None - - -class LanguageResolver(ABC): - """One step in the language fallback chain.""" - - @abstractmethod - def resolve(self, context: LanguageContext) -> str | None: - """Return a language tag, or ``None`` to defer to the next resolver.""" - - -class LastUsedLanguage(LanguageResolver): - """Reuse the language from the user's previous podcast.""" - - def resolve(self, context: LanguageContext) -> str | None: - return context.last_used - - -class DefaultLanguage(LanguageResolver): - """Terminal step: always yields the default so the chain never fails.""" - - def resolve(self, context: LanguageContext) -> str | None: - return DEFAULT_LANGUAGE - - -# Order encodes the policy; prepend stronger signals here as they appear. -DEFAULT_LANGUAGE_CHAIN: tuple[LanguageResolver, ...] = ( - LastUsedLanguage(), - DefaultLanguage(), -) - - -def resolve_language( - context: LanguageContext, - chain: tuple[LanguageResolver, ...] = DEFAULT_LANGUAGE_CHAIN, -) -> str: - """Walk ``chain`` and return the first language a resolver yields.""" - for resolver in chain: - language = resolver.resolve(context) - if language: - return language.strip() - # The default resolver guarantees a value; this guards a misconfigured chain. - return DEFAULT_LANGUAGE diff --git a/surfsense_backend/app/podcasts/resolution/voices.py b/surfsense_backend/app/podcasts/resolution/voices.py deleted file mode 100644 index 8d865fbaa..000000000 --- a/surfsense_backend/app/podcasts/resolution/voices.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Assign a default voice to each speaker for the resolved language. - -The default chain reuses the user's previously chosen voices where they are -still valid for the new language/provider, then fills any remaining speakers -with distinct catalog voices (preferring an unused gender so a two-speaker -episode sounds like two people). The user can override any of these in the -brief; this only seeds sensible defaults so most briefs need no edits. -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from app.podcasts.voices import CatalogVoice, TtsProvider, VoiceCatalog - - -class VoiceResolutionError(RuntimeError): - """No catalog voice exists for the requested provider and language.""" - - -def resolve_voices( - *, - catalog: VoiceCatalog, - provider: TtsProvider, - language: str, - speaker_count: int, - preferred: Sequence[str] | None = None, -) -> list[CatalogVoice]: - """Return one :class:`CatalogVoice` per speaker, in slot order. - - ``preferred`` is the user's last-used voice ids (by slot); any that no - longer fit the provider/language are silently dropped and replaced. - """ - if speaker_count < 1: - raise ValueError("speaker_count must be >= 1") - - available = catalog.for_language(provider, language) - if not available: - raise VoiceResolutionError( - f"{provider.value} has no voice for language {language!r}" - ) - - preferred = preferred or () - by_id = {voice.voice_id: voice for voice in available} - - assignment: list[CatalogVoice] = [] - used_ids: set[str] = set() - used_genders: set = set() - - for slot in range(speaker_count): - reuse_id = preferred[slot] if slot < len(preferred) else None - if reuse_id and reuse_id in by_id and reuse_id not in used_ids: - voice = by_id[reuse_id] - else: - voice = _pick_distinct(available, used_ids, used_genders) - assignment.append(voice) - used_ids.add(voice.voice_id) - used_genders.add(voice.gender) - - return assignment - - -def _pick_distinct( - available: list[CatalogVoice], - used_ids: set[str], - used_genders: set, -) -> CatalogVoice: - """Pick a fresh voice, preferring an unused gender, then any unused voice. - - Falls back to the first catalog voice when speakers outnumber distinct - voices, so resolution always assigns every speaker rather than failing. - """ - fresh = [v for v in available if v.voice_id not in used_ids] - if fresh: - for voice in fresh: - if voice.gender not in used_genders: - return voice - return fresh[0] - return available[0] diff --git a/surfsense_backend/app/podcasts/schemas/__init__.py b/surfsense_backend/app/podcasts/schemas/__init__.py deleted file mode 100644 index cd19a21cc..000000000 --- a/surfsense_backend/app/podcasts/schemas/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Pydantic shapes for the podcast brief and transcript.""" - -from __future__ import annotations - -from .spec import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - normalize_language_tag, -) -from .transcript import Transcript, TranscriptTurn - -__all__ = [ - "DurationTarget", - "PodcastSpec", - "PodcastStyle", - "SpeakerRole", - "SpeakerSpec", - "Transcript", - "TranscriptTurn", - "normalize_language_tag", -] diff --git a/surfsense_backend/app/podcasts/schemas/spec.py b/surfsense_backend/app/podcasts/schemas/spec.py deleted file mode 100644 index 3799d883b..000000000 --- a/surfsense_backend/app/podcasts/schemas/spec.py +++ /dev/null @@ -1,187 +0,0 @@ -"""The brief: the editable configuration a user approves before drafting. - -A :class:`PodcastSpec` front-loads every decision that drives token or audio -cost (language, speakers, voices, style, target length) so the expensive -drafting and rendering steps run once against settled inputs. It is stored as -JSONB on the ``podcasts`` row and round-trips through the review API. -""" - -from __future__ import annotations - -import re -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -from app.podcasts.duration_limits import ( - MAX_DURATION_SECONDS, - MIN_DURATION_SECONDS, -) - -# A speaker count beyond this is almost never a real podcast and explodes the -# voice/turn-attribution space, so we reject it at the brief gate. -MAX_SPEAKERS = 6 - -# BCP-47 primary subtag plus optional region (e.g. ``en``, ``en-US``, ``pt-BR``). -# Kept deliberately permissive: the voice catalog, not the brief, decides which -# languages can actually be synthesised. Casing is normalised after matching. -_LANGUAGE_TAG = re.compile(r"^[A-Za-z]{2,3}(-[A-Za-z0-9]{2,8})*$") - - -def normalize_language_tag(value: str) -> str: - """Validate and canonicalise a BCP-47 tag (lowercased primary subtag). - - Shared with the generation layer so resolved and user-entered languages are - normalised identically before they reach a :class:`PodcastSpec`. - """ - cleaned = value.strip() - if not _LANGUAGE_TAG.match(cleaned): - raise ValueError(f"not a valid BCP-47 language tag: {value!r}") - primary, _, rest = cleaned.partition("-") - return primary.lower() if not rest else f"{primary.lower()}-{rest}" - - -class SpeakerRole(StrEnum): - """How a speaker functions in the conversation, used to steer drafting.""" - - HOST = "host" - COHOST = "cohost" - GUEST = "guest" - EXPERT = "expert" - NARRATOR = "narrator" - - -class PodcastStyle(StrEnum): - """The conversational format the transcript should follow.""" - - CONVERSATIONAL = "conversational" - INTERVIEW = "interview" - DEBATE = "debate" - MONOLOGUE = "monologue" - NARRATIVE = "narrative" - - -class SpeakerSpec(BaseModel): - """One voice in the podcast: who they are and which TTS voice renders them. - - ``slot`` is the stable join key. Transcript turns reference a speaker by - ``slot`` and the renderer resolves ``voice_id`` for that same slot, so the - two never drift even if speakers are reordered in the brief. - """ - - model_config = ConfigDict(extra="forbid") - - slot: int = Field( - ..., ge=0, description="Stable index a transcript turn references" - ) - name: str = Field(..., min_length=1, max_length=120) - role: SpeakerRole - voice_id: str = Field( - ..., - min_length=1, - description="Catalog voice id valid for the spec's language and provider", - ) - - @field_validator("name", "voice_id") - @classmethod - def _strip_required_text(cls, value: str) -> str: - cleaned = value.strip() - if not cleaned: - raise ValueError("must not be blank") - return cleaned - - -class DurationTarget(BaseModel): - """The desired finished length as an inclusive second range. - - Drafting aims for the midpoint and treats the bounds as soft guardrails; - storing a range (rather than a point) keeps long-form expectations honest - without pretending we can hit an exact runtime. - """ - - model_config = ConfigDict(extra="forbid") - - min_seconds: int = Field(..., ge=MIN_DURATION_SECONDS, le=MAX_DURATION_SECONDS) - max_seconds: int = Field(..., ge=MIN_DURATION_SECONDS, le=MAX_DURATION_SECONDS) - - @model_validator(mode="before") - @classmethod - def _coerce_legacy_minutes(cls, data: Any) -> Any: - """Rows stored before seconds-based briefs still load from JSONB.""" - if ( - isinstance(data, dict) - and "min_seconds" not in data - and "min_minutes" in data - ): - migrated = dict(data) - migrated["min_seconds"] = int(migrated.pop("min_minutes")) * 60 - migrated["max_seconds"] = int(migrated.pop("max_minutes")) * 60 - return migrated - return data - - @model_validator(mode="after") - def _check_order(self) -> DurationTarget: - if self.max_seconds < self.min_seconds: - raise ValueError("max_seconds must be >= min_seconds") - return self - - @property - def midpoint_seconds(self) -> float: - """The runtime drafting should aim for within the range.""" - return (self.min_seconds + self.max_seconds) / 2 - - @property - def midpoint_minutes(self) -> float: - return self.midpoint_seconds / 60 - - -class PodcastSpec(BaseModel): - """The full brief approved before any tokens or audio are spent.""" - - model_config = ConfigDict(extra="forbid") - - language: str = Field(..., description="BCP-47 tag, e.g. 'en', 'en-US', 'pt-BR'") - style: PodcastStyle = PodcastStyle.CONVERSATIONAL - speakers: list[SpeakerSpec] = Field(..., min_length=1, max_length=MAX_SPEAKERS) - duration: DurationTarget - focus: str | None = Field( - default=None, - max_length=2000, - description="Optional user steer for what the episode should emphasise", - ) - - @field_validator("language") - @classmethod - def _normalise_language(cls, value: str) -> str: - return normalize_language_tag(value) - - @field_validator("focus") - @classmethod - def _blank_focus_is_none(cls, value: str | None) -> str | None: - if value is None: - return None - cleaned = value.strip() - return cleaned or None - - @model_validator(mode="after") - def _check_speaker_slots(self) -> PodcastSpec: - slots = [speaker.slot for speaker in self.speakers] - if len(slots) != len(set(slots)): - raise ValueError("speaker slots must be unique") - return self - - @model_validator(mode="after") - def _check_style_speakers(self) -> PodcastSpec: - # One voice is what "monologue" means; letting extra speakers through - # would force drafting to silently pick a winner. - if self.style is PodcastStyle.MONOLOGUE and len(self.speakers) != 1: - raise ValueError("a monologue has exactly one speaker") - return self - - def speaker_for(self, slot: int) -> SpeakerSpec: - """Return the speaker bound to ``slot`` or raise if none matches.""" - for speaker in self.speakers: - if speaker.slot == slot: - return speaker - raise KeyError(f"no speaker for slot {slot}") diff --git a/surfsense_backend/app/podcasts/schemas/transcript.py b/surfsense_backend/app/podcasts/schemas/transcript.py deleted file mode 100644 index b4c1463d8..000000000 --- a/surfsense_backend/app/podcasts/schemas/transcript.py +++ /dev/null @@ -1,41 +0,0 @@ -"""The transcript: ordered dialogue turns drafting produces for review. - -A :class:`Transcript` is the reviewable artifact at the go/no-go gate and the -exact input the renderer turns into audio. Each turn names a speaker by the -``slot`` defined in the :class:`~app.podcasts.schemas.spec.PodcastSpec`, so the -renderer can resolve the right voice without re-attributing anything. -""" - -from __future__ import annotations - -from pydantic import BaseModel, ConfigDict, Field, field_validator - - -class TranscriptTurn(BaseModel): - """A single spoken line by one speaker.""" - - model_config = ConfigDict(extra="forbid") - - speaker: int = Field(..., ge=0, description="The PodcastSpec speaker slot speaking") - text: str = Field(..., min_length=1) - - @field_validator("text") - @classmethod - def _strip_text(cls, value: str) -> str: - cleaned = value.strip() - if not cleaned: - raise ValueError("turn text must not be blank") - return cleaned - - -class Transcript(BaseModel): - """The full ordered dialogue for an episode.""" - - model_config = ConfigDict(extra="forbid") - - turns: list[TranscriptTurn] = Field(..., min_length=1) - - @property - def word_count(self) -> int: - """Total spoken words, used to estimate runtime against the brief.""" - return sum(len(turn.text.split()) for turn in self.turns) diff --git a/surfsense_backend/app/podcasts/service.py b/surfsense_backend/app/podcasts/service.py deleted file mode 100644 index 165bc77a4..000000000 --- a/surfsense_backend/app/podcasts/service.py +++ /dev/null @@ -1,255 +0,0 @@ -"""The podcast lifecycle authority: every status change goes through here. - -The service owns the state machine. Each method names a real lifecycle step, -validates it against the allowed-transition table, and (de)serializes the brief -and transcript to/from their JSONB columns. It deliberately does not enqueue -Celery work — callers transition the row here, then schedule the next task — so -the rules stay testable and free of task-queue coupling. -""" - -from __future__ import annotations - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.persistence import Podcast, PodcastRepository, PodcastStatus -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn - -_MAX_ERROR_CHARS = 2000 - -# The only status changes the machine permits. Terminal states have no exits. -_ALLOWED: dict[PodcastStatus, frozenset[PodcastStatus]] = { - PodcastStatus.PENDING: frozenset( - {PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - # The READY exits below exist for reverting a regeneration; the audio - # guard for that lives in revert_regeneration. - PodcastStatus.AWAITING_BRIEF: frozenset( - { - PodcastStatus.DRAFTING, - PodcastStatus.READY, - PodcastStatus.FAILED, - PodcastStatus.CANCELLED, - } - ), - PodcastStatus.DRAFTING: frozenset( - { - PodcastStatus.RENDERING, - PodcastStatus.READY, - PodcastStatus.FAILED, - PodcastStatus.CANCELLED, - } - ), - # Never entered anymore (the transcript gate was dropped); kept with exits - # so legacy rows aren't stranded. - PodcastStatus.AWAITING_REVIEW: frozenset( - {PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - PodcastStatus.RENDERING: frozenset( - {PodcastStatus.READY, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - # Not terminal: regeneration reopens the brief gate so the user can tweak - # the spec before a new take is drafted. - PodcastStatus.READY: frozenset({PodcastStatus.AWAITING_BRIEF}), - PodcastStatus.FAILED: frozenset(), - PodcastStatus.CANCELLED: frozenset(), -} - - -class PodcastError(RuntimeError): - """Base class for lifecycle errors.""" - - -class InvalidTransitionError(PodcastError): - """A requested status change is not permitted from the current state.""" - - -class SpecConflictError(PodcastError): - """A spec edit raced another: the expected version is stale.""" - - def __init__(self, expected: int, actual: int) -> None: - super().__init__( - f"spec version conflict: expected {expected}, current is {actual}" - ) - self.expected = expected - self.actual = actual - - -class PreconditionFailedError(PodcastError): - """A transition's data precondition (brief/transcript present) is unmet.""" - - -class PodcastService: - """Drives one podcast through its lifecycle within a single session.""" - - def __init__(self, session: AsyncSession) -> None: - self._session = session - self._repo = PodcastRepository(session) - - async def create( - self, *, title: str, search_space_id: int, thread_id: int | None = None - ) -> Podcast: - """Create a fresh podcast in ``PENDING`` awaiting its brief.""" - podcast = Podcast( - title=title, - search_space_id=search_space_id, - thread_id=thread_id, - status=PodcastStatus.PENDING, - spec_version=1, - ) - return await self._repo.add(podcast) - - async def attach_brief(self, podcast: Podcast, spec: PodcastSpec) -> Podcast: - """Record the proposed brief and open the review gate.""" - self._transition(podcast, PodcastStatus.AWAITING_BRIEF) - podcast.spec = spec.model_dump(mode="json") - await self._session.flush() - return podcast - - async def update_spec( - self, podcast: Podcast, spec: PodcastSpec, expected_version: int - ) -> Podcast: - """Edit the brief at the gate, guarded by optimistic concurrency.""" - if _status(podcast) is not PodcastStatus.AWAITING_BRIEF: - raise InvalidTransitionError( - f"the brief can only be edited while awaiting_brief, " - f"not {_status(podcast).value}" - ) - if expected_version != podcast.spec_version: - raise SpecConflictError(expected_version, podcast.spec_version) - podcast.spec = spec.model_dump(mode="json") - podcast.spec_version += 1 - await self._session.flush() - return podcast - - async def begin_drafting(self, podcast: Podcast) -> Podcast: - """Approve the brief and start transcript drafting.""" - if podcast.spec is None: - raise PreconditionFailedError("cannot draft without a brief") - self._transition(podcast, PodcastStatus.DRAFTING) - await self._session.flush() - return podcast - - async def attach_transcript( - self, podcast: Podcast, transcript: Transcript - ) -> Podcast: - """Record the drafted transcript and move straight to rendering.""" - self._transition(podcast, PodcastStatus.RENDERING) - podcast.podcast_transcript = transcript.model_dump(mode="json") - await self._session.flush() - return podcast - - # Guards regenerate beyond the transition table: from PENDING the - # AWAITING_BRIEF target is also legal, but there it means attaching a brief. - _REGENERABLE = frozenset({PodcastStatus.READY, PodcastStatus.AWAITING_REVIEW}) - - async def regenerate(self, podcast: Podcast) -> Podcast: - """Reopen the brief gate; the saved spec becomes the new starting point.""" - if _status(podcast) not in self._REGENERABLE: - raise InvalidTransitionError( - f"nothing to regenerate from {_status(podcast).value}" - ) - # Legacy episodes finished before briefs existed; a gate with nothing - # to review would strand them. - if podcast.spec is None: - raise PreconditionFailedError("cannot regenerate without a brief") - self._transition(podcast, PodcastStatus.AWAITING_BRIEF) - await self._session.flush() - return podcast - - async def revert_regeneration(self, podcast: Podcast) -> Podcast: - """Back out of a regeneration and fall back to the stored episode. - - Regeneration keeps the rendered audio until a new take replaces it, so - any point before that commit is a free change of mind. A fresh podcast - has no regeneration to revert and is rejected. - """ - if not has_stored_episode(podcast): - raise InvalidTransitionError("no finished episode to fall back to") - self._transition(podcast, PodcastStatus.READY) - await self._session.flush() - return podcast - - async def attach_audio( - self, - podcast: Podcast, - *, - storage_backend: str, - storage_key: str, - duration_seconds: int | None = None, - ) -> Podcast: - """Record rendered audio and mark the podcast ready.""" - self._transition(podcast, PodcastStatus.READY) - podcast.storage_backend = storage_backend - podcast.storage_key = storage_key - podcast.duration_seconds = duration_seconds - podcast.error = None - await self._session.flush() - return podcast - - async def fail(self, podcast: Podcast, error: str) -> Podcast: - """Move a non-terminal podcast to ``FAILED`` with a reason.""" - self._transition(podcast, PodcastStatus.FAILED) - podcast.error = (error or "")[:_MAX_ERROR_CHARS] or None - await self._session.flush() - return podcast - - async def cancel(self, podcast: Podcast) -> Podcast: - """Cancel a podcast that has produced nothing the user could keep. - - No user action may destroy playable audio: once an episode exists, - backing out goes through revert_regeneration instead. - """ - if has_stored_episode(podcast): - raise InvalidTransitionError( - "a finished episode exists; revert the regeneration instead" - ) - self._transition(podcast, PodcastStatus.CANCELLED) - await self._session.flush() - return podcast - - def _transition(self, podcast: Podcast, target: PodcastStatus) -> None: - current = _status(podcast) - if target not in _ALLOWED[current]: - raise InvalidTransitionError( - f"{current.value} -> {target.value} is not allowed" - ) - podcast.status = target - - -def _status(podcast: Podcast) -> PodcastStatus: - return PodcastStatus(podcast.status) - - -def has_stored_episode(podcast: Podcast) -> bool: - """Whether finished audio is stored (``file_location`` covers legacy rows).""" - return bool(podcast.storage_key or podcast.file_location) - - -def read_spec(podcast: Podcast) -> PodcastSpec | None: - """Deserialize the stored brief, or ``None`` if not yet proposed.""" - return PodcastSpec.model_validate(podcast.spec) if podcast.spec else None - - -def read_transcript(podcast: Podcast) -> Transcript | None: - """Deserialize the stored transcript, or ``None`` if not yet drafted.""" - raw = podcast.podcast_transcript - if not raw: - return None - # Rows from before the lifecycle rework stored a bare turn list with - # different field names; they must keep reading, not fail validation. - if isinstance(raw, list): - return Transcript( - turns=[ - TranscriptTurn(speaker=turn["speaker_id"], text=turn["dialog"]) - for turn in raw - ] - ) - return Transcript.model_validate(raw) - - -def preferences_from(podcast: Podcast | None) -> tuple[str | None, list[str]]: - """Extract reusable (language, voice_ids) defaults from a prior podcast.""" - spec = read_spec(podcast) if podcast is not None else None - if spec is None: - return None, [] - return spec.language, [speaker.voice_id for speaker in spec.speakers] diff --git a/surfsense_backend/app/podcasts/storage.py b/surfsense_backend/app/podcasts/storage.py deleted file mode 100644 index c3326460d..000000000 --- a/surfsense_backend/app/podcasts/storage.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Durable storage for rendered podcast audio. - -Wraps the shared :class:`StorageBackend` so the rest of the module never deals -with object keys directly. Audio is stored under a per-podcast key, streamed for -download, and purged when a podcast is deleted. -""" - -from __future__ import annotations - -import uuid -from collections.abc import AsyncIterator - -from app.file_storage.factory import get_storage_backend -from app.podcasts.persistence import Podcast - -_AUDIO_CONTENT_TYPE = "audio/mpeg" - - -def build_audio_key(*, search_space_id: int, podcast_id: int) -> str: - """Object key for a podcast's audio. - - Shape: ``podcasts/{search_space_id}/{podcast_id}/{uuid}.mp3``. The uuid lets - a re-render write a fresh object before the old one is purged. - """ - return f"podcasts/{search_space_id}/{podcast_id}/{uuid.uuid4().hex}.mp3" - - -async def store_audio( - *, search_space_id: int, podcast_id: int, data: bytes -) -> tuple[str, str]: - """Persist audio bytes and return ``(backend_name, storage_key)``.""" - backend = get_storage_backend() - key = build_audio_key(search_space_id=search_space_id, podcast_id=podcast_id) - await backend.put(key, data, content_type=_AUDIO_CONTENT_TYPE) - return backend.backend_name, key - - -def open_audio_stream(podcast: Podcast) -> AsyncIterator[bytes]: - """Stream a ready podcast's audio bytes. Raises if it has none.""" - if not podcast.storage_key: - raise FileNotFoundError(f"podcast {podcast.id} has no stored audio") - return get_storage_backend().open_stream(podcast.storage_key) - - -async def audio_exists(podcast: Podcast) -> bool: - """Whether the podcast's stored audio object is actually present.""" - return bool(podcast.storage_key) and await get_storage_backend().exists( - podcast.storage_key - ) - - -async def purge_audio(podcast: Podcast) -> None: - """Delete a podcast's stored audio if present; a missing object is fine.""" - await purge_audio_object(podcast.storage_key) - - -async def purge_audio_object(key: str | None) -> None: - """Delete a stored audio object by key, e.g. the one a re-render replaced.""" - if key: - await get_storage_backend().delete(key) diff --git a/surfsense_backend/app/podcasts/tasks/__init__.py b/surfsense_backend/app/podcasts/tasks/__init__.py deleted file mode 100644 index cd0b7e4c4..000000000 --- a/surfsense_backend/app/podcasts/tasks/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Celery tasks driving the podcast lifecycle across its expensive phases. - -One task per heavy async phase: draft the transcript (LLM) and render the audio -(TTS). The brief is deterministic and proposed inline at create time, so it has -no task. Each task is enqueued by the API after it performs the guarded status -transition, and each pushes its result onto the row for the frontend to observe. -""" - -from __future__ import annotations - -from .draft import draft_transcript_task -from .render import render_audio_task - -__all__ = [ - "draft_transcript_task", - "render_audio_task", -] diff --git a/surfsense_backend/app/podcasts/tasks/draft.py b/surfsense_backend/app/podcasts/tasks/draft.py deleted file mode 100644 index c5b489571..000000000 --- a/surfsense_backend/app/podcasts/tasks/draft.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Transcript-drafting task: DRAFTING -> RENDERING. - -The expensive, LLM-heavy step, so it runs under ``billable_call``. The API has -already moved the row to DRAFTING and stored the approved brief; this task -drafts the long-form transcript and chains straight into the render — the brief -gate is the only approval in the lifecycle. -""" - -from __future__ import annotations - -import logging - -from app.celery_app import celery_app -from app.config import config as app_config -from app.podcasts.generation.transcript.graph import graph as transcript_graph -from app.podcasts.generation.transcript.state import TranscriptState -from app.podcasts.persistence import PodcastRepository -from app.podcasts.service import PodcastService, read_spec -from app.services.billable_calls import ( - BillingSettlementError, - QuotaInsufficientError, - _resolve_agent_billing_for_search_space, - billable_call, -) -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -from .render import render_audio_task -from .runtime import billable_session, mark_failed - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="podcast.draft_transcript", bind=True) -def draft_transcript_task(self, podcast_id: int, search_space_id: int) -> dict: - try: - return run_async_celery_task( - lambda: _draft_transcript(podcast_id, search_space_id) - ) - except Exception as exc: - logger.error("Podcast %s drafting failed: %s", podcast_id, exc) - message = str(exc) - run_async_celery_task(lambda: mark_failed(podcast_id, message)) - return {"status": "failed", "podcast_id": podcast_id} - - -async def _draft_transcript(podcast_id: int, search_space_id: int) -> dict: - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - service = PodcastService(session) - podcast = await repo.get(podcast_id) - if podcast is None: - raise ValueError(f"podcast {podcast_id} not found") - - spec = read_spec(podcast) - if spec is None: - raise ValueError(f"podcast {podcast_id} has no approved brief") - - owner_id, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id, thread_id=podcast.thread_id - ) - - state = TranscriptState( - db_session=session, source_content=podcast.source_content or "" - ) - config = { - "configurable": { - "search_space_id": search_space_id, - "spec": spec, - "focus": spec.focus, - } - } - - try: - async with billable_call( - user_id=owner_id, - search_space_id=search_space_id, - billing_tier=tier, - base_model=base_model, - quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, - usage_type="podcast_generation", - call_details={"podcast_id": podcast_id, "title": podcast.title}, - billable_session_factory=billable_session, - ): - result = await transcript_graph.ainvoke(state, config=config) - except QuotaInsufficientError: - await service.fail(podcast, "premium quota exhausted") - await session.commit() - return {"status": "failed", "podcast_id": podcast_id, "reason": "quota"} - except BillingSettlementError: - await service.fail(podcast, "billing settlement failed") - await session.commit() - return {"status": "failed", "podcast_id": podcast_id, "reason": "billing"} - - await service.attach_transcript(podcast, result["transcript"]) - await session.commit() - - # Enqueue only after the transaction is committed, so the render worker can - # never pick up a row whose transcript isn't visible yet. - render_audio_task.delay(podcast_id) - return {"status": "rendering", "podcast_id": podcast_id} diff --git a/surfsense_backend/app/podcasts/tasks/render.py b/surfsense_backend/app/podcasts/tasks/render.py deleted file mode 100644 index 2e550a868..000000000 --- a/surfsense_backend/app/podcasts/tasks/render.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Audio-rendering task: RENDERING -> READY. - -Synthesises and merges the approved transcript, stores the MP3 in the object -store, and marks the podcast ready. The working directory is stable per podcast -so a re-render (e.g. after a voice change) reuses the segment cache. -""" - -from __future__ import annotations - -import logging -import tempfile -from pathlib import Path - -from app.celery_app import celery_app -from app.podcasts.persistence import PodcastRepository -from app.podcasts.rendering import PodcastRenderer -from app.podcasts.service import ( - InvalidTransitionError, - PodcastService, - read_spec, - read_transcript, -) -from app.podcasts.storage import purge_audio_object, store_audio -from app.podcasts.tts import get_text_to_speech -from app.podcasts.voices import get_voice_catalog -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -from .runtime import mark_failed - -logger = logging.getLogger(__name__) - -_WORKDIR_BASE = Path(tempfile.gettempdir()) / "surfsense_podcasts" - - -@celery_app.task(name="podcast.render_audio", bind=True) -def render_audio_task(self, podcast_id: int) -> dict: - try: - return run_async_celery_task(lambda: _render_audio(podcast_id)) - except Exception as exc: - logger.error("Podcast %s render failed: %s", podcast_id, exc) - message = str(exc) - run_async_celery_task(lambda: mark_failed(podcast_id, message)) - return {"status": "failed", "podcast_id": podcast_id} - - -async def _render_audio(podcast_id: int) -> dict: - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - podcast = await repo.get(podcast_id) - if podcast is None: - raise ValueError(f"podcast {podcast_id} not found") - - spec = read_spec(podcast) - transcript = read_transcript(podcast) - if spec is None or transcript is None: - raise ValueError(f"podcast {podcast_id} is missing brief or transcript") - - renderer = PodcastRenderer( - tts=get_text_to_speech(), catalog=get_voice_catalog() - ) - workdir = _WORKDIR_BASE / str(podcast_id) - workdir.mkdir(parents=True, exist_ok=True) - rendered = await renderer.render( - spec=spec, transcript=transcript, workdir=workdir - ) - - superseded_key = podcast.storage_key - - backend_name, key = await store_audio( - search_space_id=podcast.search_space_id, - podcast_id=podcast_id, - data=rendered.data, - ) - try: - await PodcastService(session).attach_audio( - podcast, storage_backend=backend_name, storage_key=key - ) - await session.commit() - except InvalidTransitionError: - # A user back-out won the race (e.g. the regeneration was - # reverted): drop the stale render and leave the row alone. - await purge_audio_object(key) - return {"status": "superseded", "podcast_id": podcast_id} - - # Purge only after the new audio is committed, so a failed re-render never - # destroys the episode the user can still play. - await purge_audio_object(superseded_key) - return {"status": "ready", "podcast_id": podcast_id} diff --git a/surfsense_backend/app/podcasts/tasks/runtime.py b/surfsense_backend/app/podcasts/tasks/runtime.py deleted file mode 100644 index 349aeffb2..000000000 --- a/surfsense_backend/app/podcasts/tasks/runtime.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Shared plumbing for the podcast Celery tasks. - -Each task runs its async body via :func:`run_async_celery_task` and, on any -failure, records the reason on the row through the lifecycle service. Marking -failed is best-effort: a podcast that already reached a terminal state is left -untouched rather than forced. -""" - -from __future__ import annotations - -import logging -from contextlib import asynccontextmanager - -from app.podcasts.persistence import PodcastRepository -from app.podcasts.service import PodcastError, PodcastService -from app.tasks.celery_tasks import get_celery_session_maker - -logger = logging.getLogger(__name__) - - -@asynccontextmanager -async def billable_session(): - """Session factory for ``billable_call`` inside the worker loop.""" - async with get_celery_session_maker()() as session: - yield session - - -async def mark_failed(podcast_id: int, error: str) -> None: - """Best-effort: move a non-terminal podcast to FAILED with ``error``.""" - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - podcast = await repo.get(podcast_id) - if podcast is None: - return - try: - await PodcastService(session).fail(podcast, error) - await session.commit() - except PodcastError: - # Already terminal (e.g. cancelled): nothing to record. - logger.info("Podcast %s already terminal; not marking failed", podcast_id) diff --git a/surfsense_backend/app/podcasts/tts/__init__.py b/surfsense_backend/app/podcasts/tts/__init__.py deleted file mode 100644 index 16379dc2b..000000000 --- a/surfsense_backend/app/podcasts/tts/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Text-to-speech: a per-segment synthesis port with provider adapters. - -Callers depend on :class:`TextToSpeech` and obtain the configured provider from -:func:`get_text_to_speech`; the concrete Kokoro/LiteLLM adapters stay private. -""" - -from __future__ import annotations - -from .audio import SynthesizedAudio -from .errors import TextToSpeechError -from .factory import get_text_to_speech -from .port import TextToSpeech -from .request import SynthesisRequest, VoiceRef - -__all__ = [ - "SynthesisRequest", - "SynthesizedAudio", - "TextToSpeech", - "TextToSpeechError", - "VoiceRef", - "get_text_to_speech", -] diff --git a/surfsense_backend/app/podcasts/tts/adapters/__init__.py b/surfsense_backend/app/podcasts/tts/adapters/__init__.py deleted file mode 100644 index 24d517e55..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Per-provider TextToSpeech implementations.""" - -from __future__ import annotations diff --git a/surfsense_backend/app/podcasts/tts/adapters/kokoro.py b/surfsense_backend/app/podcasts/tts/adapters/kokoro.py deleted file mode 100644 index 2ef0069c5..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/kokoro.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Local Kokoro adapter: on-box synthesis, no network or per-segment cost. - -Kokoro selects its language model by a single-letter ``lang_code``, so this -adapter maps the brief's BCP-47 tag to that code and caches one pipeline per -code (pipeline construction loads weights and is expensive). Pipelines run in a -thread pool because Kokoro is synchronous; the renderer caps how many segments -synthesise at once. -""" - -from __future__ import annotations - -import asyncio -import io -from typing import TYPE_CHECKING - -from ..audio import SynthesizedAudio -from ..errors import TextToSpeechError -from ..port import TextToSpeech -from ..request import SynthesisRequest - -if TYPE_CHECKING: - from kokoro import KPipeline - -# Kokoro emits 24 kHz mono PCM regardless of voice. -_SAMPLE_RATE = 24000 - -# BCP-47 primary subtag -> Kokoro language code. English defaults to American; -# the en-GB region override below switches it to British. -_LANG_CODE_BY_PRIMARY = { - "en": "a", - "es": "e", - "fr": "f", - "hi": "h", - "it": "i", - "ja": "j", - "pt": "p", - "zh": "z", -} - - -class KokoroTextToSpeech(TextToSpeech): - """Synthesises segments with locally hosted Kokoro pipelines.""" - - def __init__(self) -> None: - self._pipelines: dict[str, KPipeline] = {} - - @property - def container(self) -> str: - return "wav" - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - if not isinstance(request.voice, str): - raise TextToSpeechError("Kokoro voices are named by string, not a mapping") - - pipeline = self._pipeline_for(request.language) - loop = asyncio.get_event_loop() - try: - generator = await loop.run_in_executor( - None, - lambda: pipeline( - request.text, - voice=request.voice, - speed=request.speed, - split_pattern=r"\n+", - ), - ) - segments = [audio for _gs, _ps, audio in generator] - except Exception as exc: - raise TextToSpeechError(f"Kokoro synthesis failed: {exc}") from exc - - if not segments: - raise TextToSpeechError("Kokoro produced no audio for the text") - - return SynthesizedAudio( - data=_encode_wav(segments, _SAMPLE_RATE), - container="wav", - sample_rate=_SAMPLE_RATE, - ) - - def _pipeline_for(self, language: str) -> KPipeline: - lang_code = _lang_code(language) - pipeline = self._pipelines.get(lang_code) - if pipeline is None: - from kokoro import KPipeline - - pipeline = KPipeline(lang_code=lang_code) - self._pipelines[lang_code] = pipeline - return pipeline - - -def _lang_code(language: str) -> str: - normalised = language.strip().lower() - if normalised.startswith("en-gb") or normalised == "en-uk": - return "b" - primary = normalised.partition("-")[0] - code = _LANG_CODE_BY_PRIMARY.get(primary) - if code is None: - raise TextToSpeechError(f"Kokoro has no language model for {language!r}") - return code - - -def _encode_wav(segments: list, sample_rate: int) -> bytes: - import numpy as np - import soundfile as sf - - waveform = segments[0] if len(segments) == 1 else np.concatenate(segments) - buffer = io.BytesIO() - sf.write(buffer, waveform, sample_rate, format="WAV") - return buffer.getvalue() diff --git a/surfsense_backend/app/podcasts/tts/adapters/litellm.py b/surfsense_backend/app/podcasts/tts/adapters/litellm.py deleted file mode 100644 index d0014c5cd..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/litellm.py +++ /dev/null @@ -1,67 +0,0 @@ -"""LiteLLM adapter: hosted TTS (OpenAI, Azure, Vertex AI) via one ``aspeech`` call. - -LiteLLM normalises every hosted provider behind the same ``aspeech`` surface, -so a single adapter covers them all. The provider is encoded in the model -string (e.g. ``openai/tts-1``, ``vertex_ai/...``) and the voice reference is -whatever that provider expects, which the catalog already supplies. -""" - -from __future__ import annotations - -from ..audio import SynthesizedAudio -from ..errors import TextToSpeechError -from ..port import TextToSpeech -from ..request import SynthesisRequest - -# Hosted providers return MP3-encoded bytes from ``aspeech``. -_CONTAINER = "mp3" - -# A long single segment still finishes well under this; retries absorb transient -# upstream failures without failing the whole render. -_TIMEOUT_SECONDS = 600 -_MAX_RETRIES = 2 - - -class LiteLlmTextToSpeech(TextToSpeech): - """Synthesises segments through any LiteLLM-supported hosted TTS model.""" - - def __init__( - self, - *, - model: str, - api_base: str | None = None, - api_key: str | None = None, - ) -> None: - self._model = model - self._api_base = api_base - self._api_key = api_key - - @property - def container(self) -> str: - return _CONTAINER - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - from litellm import aspeech - - kwargs = { - "model": self._model, - "voice": request.voice, - "input": request.text, - "max_retries": _MAX_RETRIES, - "timeout": _TIMEOUT_SECONDS, - } - if self._api_base: - kwargs["api_base"] = self._api_base - if self._api_key: - kwargs["api_key"] = self._api_key - - try: - response = await aspeech(**kwargs) - except Exception as exc: - raise TextToSpeechError(f"{self._model} synthesis failed: {exc}") from exc - - data = getattr(response, "content", None) - if not data: - raise TextToSpeechError(f"{self._model} returned no audio") - - return SynthesizedAudio(data=data, container=_CONTAINER) diff --git a/surfsense_backend/app/podcasts/tts/audio.py b/surfsense_backend/app/podcasts/tts/audio.py deleted file mode 100644 index f3c79dd5a..000000000 --- a/surfsense_backend/app/podcasts/tts/audio.py +++ /dev/null @@ -1,19 +0,0 @@ -"""The bytes a TTS provider returns for one segment.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class SynthesizedAudio: - """Encoded audio for a single segment, ready to cache and concatenate. - - ``container`` is the file extension the bytes are encoded as (``"wav"`` or - ``"mp3"``); the renderer uses it to name the on-disk segment so FFmpeg can - demux the right format during merge. - """ - - data: bytes - container: str - sample_rate: int | None = None diff --git a/surfsense_backend/app/podcasts/tts/errors.py b/surfsense_backend/app/podcasts/tts/errors.py deleted file mode 100644 index 8e7ec3f2b..000000000 --- a/surfsense_backend/app/podcasts/tts/errors.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Failures raised by the TTS layer.""" - -from __future__ import annotations - - -class TextToSpeechError(RuntimeError): - """A provider failed to synthesise a segment. - - Raised for both configuration faults (an unusable voice reference) and - provider faults (the upstream call errored or returned no audio), so the - renderer can fail the segment without unwrapping provider-specific - exceptions. - """ diff --git a/surfsense_backend/app/podcasts/tts/factory.py b/surfsense_backend/app/podcasts/tts/factory.py deleted file mode 100644 index 7b4a48adf..000000000 --- a/surfsense_backend/app/podcasts/tts/factory.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Resolve the configured :class:`TextToSpeech` as a process-wide singleton.""" - -from __future__ import annotations - -from functools import lru_cache - -from .port import TextToSpeech - -# Sentinel model string that selects the local Kokoro pipeline; anything else is -# treated as a LiteLLM-hosted model (``openai/...``, ``vertex_ai/...``, etc.). -KOKORO_SERVICE = "local/kokoro" - - -@lru_cache(maxsize=1) -def get_text_to_speech() -> TextToSpeech: - """Build the provider selected by ``TTS_SERVICE`` (adapters lazy-imported). - - Cached because the Kokoro adapter holds loaded pipelines that must be reused - across segments and requests rather than rebuilt per call. - """ - from app.config import config as app_config - - service = app_config.TTS_SERVICE - if not service: - raise ValueError("TTS_SERVICE is not configured") - - if service == KOKORO_SERVICE: - from .adapters.kokoro import KokoroTextToSpeech - - return KokoroTextToSpeech() - - from .adapters.litellm import LiteLlmTextToSpeech - - return LiteLlmTextToSpeech( - model=service, - api_base=app_config.TTS_SERVICE_API_BASE, - api_key=app_config.TTS_SERVICE_API_KEY, - ) diff --git a/surfsense_backend/app/podcasts/tts/port.py b/surfsense_backend/app/podcasts/tts/port.py deleted file mode 100644 index 604708260..000000000 --- a/surfsense_backend/app/podcasts/tts/port.py +++ /dev/null @@ -1,31 +0,0 @@ -"""The TTS contract: turn one segment of text into encoded audio.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod - -from .audio import SynthesizedAudio -from .request import SynthesisRequest - - -class TextToSpeech(ABC): - """Synthesises a single segment; one implementation per provider. - - The contract is intentionally per-segment rather than per-episode: it keeps - each call independently cacheable and lets the renderer cap concurrency and - retry segments in isolation. Stitching segments into one file is the - renderer's job, not the provider's. - """ - - @property - @abstractmethod - def container(self) -> str: - """File extension/container this provider emits (e.g. ``"mp3"``).""" - - @abstractmethod - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - """Voice ``request.text`` and return its encoded audio. - - Raises :class:`~app.podcasts.tts.errors.TextToSpeechError` on any - provider or configuration failure. - """ diff --git a/surfsense_backend/app/podcasts/tts/request.py b/surfsense_backend/app/podcasts/tts/request.py deleted file mode 100644 index 2cb5f6ec4..000000000 --- a/surfsense_backend/app/podcasts/tts/request.py +++ /dev/null @@ -1,22 +0,0 @@ -"""What the renderer hands a TTS provider to voice a single segment.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any - -# A provider-native voice reference. OpenAI/Azure/Kokoro name a voice with a -# string; Vertex passes a mapping (``languageCode`` + ``name``). The catalog -# stores whichever shape the provider expects and we pass it through untouched. -VoiceRef = str | Mapping[str, Any] - - -@dataclass(frozen=True, slots=True) -class SynthesisRequest: - """One unit of speech to synthesise: the smallest cacheable render step.""" - - text: str - voice: VoiceRef - language: str - speed: float = 1.0 diff --git a/surfsense_backend/app/podcasts/voices/__init__.py b/surfsense_backend/app/podcasts/voices/__init__.py deleted file mode 100644 index 97874a655..000000000 --- a/surfsense_backend/app/podcasts/voices/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Voices: the catalog of selectable TTS voices and the active provider. - -Callers obtain the catalog via :func:`get_voice_catalog` and identify the -configured provider via :func:`provider_from_service`. -""" - -from __future__ import annotations - -from .catalog import LanguageOffering, VoiceCatalog, get_voice_catalog -from .preview import render_voice_preview -from .provider import TtsProvider, provider_from_service -from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - -__all__ = [ - "ANY_LANGUAGE", - "CatalogVoice", - "LanguageOffering", - "TtsProvider", - "VoiceCatalog", - "VoiceGender", - "get_voice_catalog", - "provider_from_service", - "render_voice_preview", -] diff --git a/surfsense_backend/app/podcasts/voices/catalog.py b/surfsense_backend/app/podcasts/voices/catalog.py deleted file mode 100644 index 6bf39510a..000000000 --- a/surfsense_backend/app/podcasts/voices/catalog.py +++ /dev/null @@ -1,80 +0,0 @@ -"""The voice catalog: look up and filter selectable voices. - -A :class:`VoiceCatalog` is the single source of truth for which voices exist. -Resolution uses it to pick defaults for a brief, the API exposes it as picker -options, and the renderer uses it to turn a stored ``voice_id`` back into the -provider-native reference. -""" - -from __future__ import annotations - -from collections.abc import Iterable -from dataclasses import dataclass -from functools import lru_cache - -from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES -from .data.languages import COMMON_LANGUAGES -from .provider import TtsProvider -from .voice import ANY_LANGUAGE, CatalogVoice - - -@dataclass(frozen=True, slots=True) -class LanguageOffering: - """The languages a provider's roster can offer the brief form. - - ``allows_custom`` is true when the roster has wildcard voices: the listed - languages are then a curated starting point, not a limit, and any BCP-47 - tag may be entered. - """ - - languages: list[str] - allows_custom: bool - - -class VoiceCatalog: - """An indexed, read-only collection of :class:`CatalogVoice`.""" - - def __init__(self, voices: Iterable[CatalogVoice]) -> None: - self._by_id: dict[str, CatalogVoice] = {} - self._by_provider: dict[TtsProvider, list[CatalogVoice]] = {} - for voice in voices: - if voice.voice_id in self._by_id: - raise ValueError(f"duplicate voice_id: {voice.voice_id}") - self._by_id[voice.voice_id] = voice - self._by_provider.setdefault(voice.provider, []).append(voice) - - def get(self, voice_id: str) -> CatalogVoice: - """Return the voice with ``voice_id`` or raise ``KeyError``.""" - return self._by_id[voice_id] - - def for_provider(self, provider: TtsProvider) -> list[CatalogVoice]: - """All voices offered by ``provider``, in catalog order.""" - return list(self._by_provider.get(provider, ())) - - def for_language(self, provider: TtsProvider, language: str) -> list[CatalogVoice]: - """``provider`` voices that can render ``language``, in catalog order.""" - return [v for v in self.for_provider(provider) if v.speaks(language)] - - def supports_language(self, provider: TtsProvider, language: str) -> bool: - """Whether ``provider`` has at least one voice for ``language``.""" - return any(v.speaks(language) for v in self.for_provider(provider)) - - def offerable_languages(self, provider: TtsProvider) -> LanguageOffering: - """The languages ``provider`` can offer up front. - - Language-bound voices contribute their concrete tags; wildcard voices - cannot enumerate languages, so their presence merges in the curated - common list and opens free entry. - """ - voices = self.for_provider(provider) - tags = {v.language for v in voices if v.language != ANY_LANGUAGE} - has_wildcard = any(v.language == ANY_LANGUAGE for v in voices) - if has_wildcard: - tags.update(COMMON_LANGUAGES) - return LanguageOffering(languages=sorted(tags), allows_custom=has_wildcard) - - -@lru_cache(maxsize=1) -def get_voice_catalog() -> VoiceCatalog: - """The process-wide catalog assembled from every provider's roster.""" - return VoiceCatalog((*KOKORO_VOICES, *OPENAI_VOICES, *AZURE_VOICES, *VERTEX_VOICES)) diff --git a/surfsense_backend/app/podcasts/voices/data/__init__.py b/surfsense_backend/app/podcasts/voices/data/__init__.py deleted file mode 100644 index 5316f10f6..000000000 --- a/surfsense_backend/app/podcasts/voices/data/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Static per-provider voice rosters that compose the catalog.""" - -from __future__ import annotations - -from .azure import AZURE_VOICES -from .kokoro import KOKORO_VOICES -from .openai import OPENAI_VOICES -from .vertex import VERTEX_VOICES - -__all__ = ["AZURE_VOICES", "KOKORO_VOICES", "OPENAI_VOICES", "VERTEX_VOICES"] diff --git a/surfsense_backend/app/podcasts/voices/data/azure.py b/surfsense_backend/app/podcasts/voices/data/azure.py deleted file mode 100644 index 104ab766d..000000000 --- a/surfsense_backend/app/podcasts/voices/data/azure.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Azure TTS voices, routed through the OpenAI-compatible voice names. - -The deployment fronts Azure with OpenAI-style voice names (matching the legacy -podcaster), so these mirror the OpenAI roster and, like it, speak any requested -language. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - - -def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"azure:{name}", - provider=TtsProvider.AZURE, - language=ANY_LANGUAGE, - display_name=display, - gender=gender, - native_ref=name, - ) - - -AZURE_VOICES: tuple[CatalogVoice, ...] = ( - _voice("alloy", "Alloy", VoiceGender.NEUTRAL), - _voice("echo", "Echo", VoiceGender.MALE), - _voice("fable", "Fable", VoiceGender.NEUTRAL), - _voice("onyx", "Onyx", VoiceGender.MALE), - _voice("nova", "Nova", VoiceGender.FEMALE), - _voice("shimmer", "Shimmer", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/kokoro.py b/surfsense_backend/app/podcasts/voices/data/kokoro.py deleted file mode 100644 index 732dced23..000000000 --- a/surfsense_backend/app/podcasts/voices/data/kokoro.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Curated Kokoro voices, the local provider's multilingual roster. - -Kokoro voice names encode language and gender in their first two letters -(``a``=American English, ``b``=British, ``e``=Spanish, ``f``=French, -``h``=Hindi, ``i``=Italian, ``j``=Japanese, ``p``=Brazilian Portuguese, -``z``=Mandarin; second letter ``f``/``m`` = female/male). We carry at least one -male and one female voice per language so a two-speaker brief always has a -distinct pair. ``native_ref`` is the bare voice name Kokoro expects. - -Reference: https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import CatalogVoice, VoiceGender - - -def _voice(name: str, language: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"kokoro:{name}", - provider=TtsProvider.KOKORO, - language=language, - display_name=display, - gender=gender, - native_ref=name, - ) - - -KOKORO_VOICES: tuple[CatalogVoice, ...] = ( - # American English - _voice("am_adam", "en-US", "Adam (US)", VoiceGender.MALE), - _voice("am_michael", "en-US", "Michael (US)", VoiceGender.MALE), - _voice("af_bella", "en-US", "Bella (US)", VoiceGender.FEMALE), - _voice("af_heart", "en-US", "Heart (US)", VoiceGender.FEMALE), - _voice("af_nicole", "en-US", "Nicole (US)", VoiceGender.FEMALE), - _voice("af_sarah", "en-US", "Sarah (US)", VoiceGender.FEMALE), - # British English - _voice("bm_george", "en-GB", "George (UK)", VoiceGender.MALE), - _voice("bm_lewis", "en-GB", "Lewis (UK)", VoiceGender.MALE), - _voice("bf_emma", "en-GB", "Emma (UK)", VoiceGender.FEMALE), - _voice("bf_isabella", "en-GB", "Isabella (UK)", VoiceGender.FEMALE), - # Spanish - _voice("em_alex", "es", "Alex (ES)", VoiceGender.MALE), - _voice("ef_dora", "es", "Dora (ES)", VoiceGender.FEMALE), - # French - _voice("ff_siwis", "fr", "Siwis (FR)", VoiceGender.FEMALE), - # Hindi - _voice("hm_omega", "hi", "Omega (HI)", VoiceGender.MALE), - _voice("hf_alpha", "hi", "Alpha (HI)", VoiceGender.FEMALE), - # Italian - _voice("im_nicola", "it", "Nicola (IT)", VoiceGender.MALE), - _voice("if_sara", "it", "Sara (IT)", VoiceGender.FEMALE), - # Japanese - _voice("jm_kumo", "ja", "Kumo (JA)", VoiceGender.MALE), - _voice("jf_alpha", "ja", "Alpha (JA)", VoiceGender.FEMALE), - # Brazilian Portuguese - _voice("pm_alex", "pt-BR", "Alex (BR)", VoiceGender.MALE), - _voice("pf_dora", "pt-BR", "Dora (BR)", VoiceGender.FEMALE), - # Mandarin Chinese - _voice("zm_yunxi", "zh", "Yunxi (ZH)", VoiceGender.MALE), - _voice("zf_xiaoxiao", "zh", "Xiaoxiao (ZH)", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/languages.py b/surfsense_backend/app/podcasts/voices/data/languages.py deleted file mode 100644 index c00fd7f05..000000000 --- a/surfsense_backend/app/podcasts/voices/data/languages.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Curated languages offered when a roster has wildcard (any-language) voices. - -OpenAI-style multilingual voices speak whatever language the text is in, so -there is no provider list to enumerate. This is the set the brief form offers -up front for such providers; it is an offering, not a limit — the API flags -``allows_custom`` so users can enter any BCP-47 tag beyond it. -""" - -from __future__ import annotations - -COMMON_LANGUAGES: tuple[str, ...] = ( - "ar", - "bn", - "de", - "en", - "es", - "fr", - "hi", - "id", - "it", - "ja", - "ko", - "nl", - "pl", - "pt", - "ru", - "sw", - "th", - "tr", - "uk", - "vi", - "zh", -) diff --git a/surfsense_backend/app/podcasts/voices/data/openai.py b/surfsense_backend/app/podcasts/voices/data/openai.py deleted file mode 100644 index ce5c480c5..000000000 --- a/surfsense_backend/app/podcasts/voices/data/openai.py +++ /dev/null @@ -1,32 +0,0 @@ -"""OpenAI TTS voices: language-agnostic, so each speaks any requested language. - -OpenAI voices follow the language of the input text rather than being tied to a -locale, so they are tagged :data:`ANY_LANGUAGE` and match every brief. The -``native_ref`` is the plain voice name the API expects. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - - -def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"openai:{name}", - provider=TtsProvider.OPENAI, - language=ANY_LANGUAGE, - display_name=display, - gender=gender, - native_ref=name, - ) - - -OPENAI_VOICES: tuple[CatalogVoice, ...] = ( - _voice("alloy", "Alloy", VoiceGender.NEUTRAL), - _voice("echo", "Echo", VoiceGender.MALE), - _voice("fable", "Fable", VoiceGender.NEUTRAL), - _voice("onyx", "Onyx", VoiceGender.MALE), - _voice("nova", "Nova", VoiceGender.FEMALE), - _voice("shimmer", "Shimmer", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/vertex.py b/surfsense_backend/app/podcasts/voices/data/vertex.py deleted file mode 100644 index 99477eb21..000000000 --- a/surfsense_backend/app/podcasts/voices/data/vertex.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Vertex AI Studio voices: locale-specific, referenced by a mapping. - -Vertex voices are tied to a locale and named via a ``{languageCode, name}`` -mapping, which is exactly the ``native_ref`` the LiteLLM adapter forwards. The -values mirror the legacy podcaster's English Studio voices. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import CatalogVoice, VoiceGender - - -def _voice( - key: str, - language: str, - locale: str, - name: str, - display: str, - gender: VoiceGender, -) -> CatalogVoice: - return CatalogVoice( - voice_id=f"vertex_ai:{key}", - provider=TtsProvider.VERTEX_AI, - language=language, - display_name=display, - gender=gender, - native_ref={"languageCode": locale, "name": name}, - ) - - -VERTEX_VOICES: tuple[CatalogVoice, ...] = ( - _voice( - "en-US-Studio-O", - "en-US", - "en-US", - "en-US-Studio-O", - "Studio O (US)", - VoiceGender.FEMALE, - ), - _voice( - "en-US-Studio-M", - "en-US", - "en-US", - "en-US-Studio-M", - "Studio M (US)", - VoiceGender.MALE, - ), - _voice( - "en-GB-Studio-A", - "en-GB", - "en-UK", - "en-UK-Studio-A", - "Studio A (UK)", - VoiceGender.FEMALE, - ), - _voice( - "en-GB-Studio-B", - "en-GB", - "en-UK", - "en-UK-Studio-B", - "Studio B (UK)", - VoiceGender.MALE, - ), - _voice( - "en-AU-Studio-A", - "en-AU", - "en-AU", - "en-AU-Studio-A", - "Studio A (AU)", - VoiceGender.FEMALE, - ), - _voice( - "en-AU-Studio-B", - "en-AU", - "en-AU", - "en-AU-Studio-B", - "Studio B (AU)", - VoiceGender.MALE, - ), -) diff --git a/surfsense_backend/app/podcasts/voices/preview.py b/surfsense_backend/app/podcasts/voices/preview.py deleted file mode 100644 index 868504a91..000000000 --- a/surfsense_backend/app/podcasts/voices/preview.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Audible previews so users pick voices by sound, not by name. - -A preview is a short sample sentence synthesised in the voice's own language. -Samples are served through the same content-addressed cache the renderer uses, -so each voice costs at most one synthesis per cache lifetime — repeat listens -while comparing voices are free. -""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -from app.podcasts.rendering.cache import SegmentCache -from app.podcasts.tts import SynthesisRequest, TextToSpeech - -from .voice import ANY_LANGUAGE, CatalogVoice - -# Previews are user-independent, so one rendered sample serves everyone. -PREVIEW_CACHE_ROOT = Path(tempfile.gettempdir()) / "surfsense_podcasts" / "previews" - -_FALLBACK_LANGUAGE = "en" - -# A voice previews best speaking its own language. -_SAMPLE_TEXTS = { - "en": "Hi there! This is how I sound when narrating your podcast.", - "es": "¡Hola! Así sueno cuando narro tu pódcast.", - "fr": "Bonjour ! Voici ma voix quand je raconte votre podcast.", - "hi": "नमस्ते! आपका पॉडकास्ट सुनाते समय मेरी आवाज़ ऐसी होती है।", - "it": "Ciao! Questa è la mia voce quando racconto il tuo podcast.", - "ja": "こんにちは。ポッドキャストをお届けするときの私の声です。", - "pt": "Olá! É assim que eu soo ao narrar o seu podcast.", - "zh": "你好!这就是我为你播报播客时的声音。", # noqa: RUF001 -} - -_CONTENT_TYPES = {"mp3": "audio/mpeg", "wav": "audio/wav"} - - -async def render_voice_preview( - voice: CatalogVoice, tts: TextToSpeech -) -> tuple[bytes, str]: - """Return ``(audio_bytes, content_type)`` for a sample spoken by ``voice``.""" - language = _FALLBACK_LANGUAGE if voice.language == ANY_LANGUAGE else voice.language - request = SynthesisRequest( - text=_sample_text(language), voice=voice.native_ref, language=language - ) - - cache = SegmentCache(PREVIEW_CACHE_ROOT) - key = cache.key(request) - cached = cache.get(key, tts.container) - if cached is not None: - return cached.read_bytes(), _content_type(tts.container) - - audio = await tts.synthesize(request) - cache.put(key, audio.container, audio.data) - return audio.data, _content_type(audio.container) - - -def _sample_text(language: str) -> str: - primary = language.split("-", 1)[0].strip().lower() - return _SAMPLE_TEXTS.get(primary, _SAMPLE_TEXTS[_FALLBACK_LANGUAGE]) - - -def _content_type(container: str) -> str: - return _CONTENT_TYPES.get(container, "application/octet-stream") diff --git a/surfsense_backend/app/podcasts/voices/provider.py b/surfsense_backend/app/podcasts/voices/provider.py deleted file mode 100644 index f57ae11cc..000000000 --- a/surfsense_backend/app/podcasts/voices/provider.py +++ /dev/null @@ -1,27 +0,0 @@ -"""The TTS providers we carry voices for, and how to name one from config.""" - -from __future__ import annotations - -from enum import StrEnum - - -class TtsProvider(StrEnum): - """A speech provider whose voices the catalog enumerates.""" - - KOKORO = "kokoro" - OPENAI = "openai" - AZURE = "azure" - VERTEX_AI = "vertex_ai" - - -def provider_from_service(service: str) -> TtsProvider: - """Map a ``TTS_SERVICE`` string to its provider. - - The config value is a LiteLLM-style ``provider/model`` string - (``openai/tts-1``, ``vertex_ai/...``) except for local Kokoro, which is - spelled ``local/kokoro``; both halves of that special case resolve here. - """ - prefix = service.split("/", 1)[0].strip().lower() - if prefix == "local": - return TtsProvider.KOKORO - return TtsProvider(prefix) diff --git a/surfsense_backend/app/podcasts/voices/voice.py b/surfsense_backend/app/podcasts/voices/voice.py deleted file mode 100644 index 6478f04b0..000000000 --- a/surfsense_backend/app/podcasts/voices/voice.py +++ /dev/null @@ -1,50 +0,0 @@ -"""A catalog voice: a stable id paired with its provider-native reference.""" - -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum - -from app.podcasts.tts import VoiceRef - -from .provider import TtsProvider - -# A voice that speaks whatever language the input text is in (e.g. OpenAI's -# voices), matched against every requested language. -ANY_LANGUAGE = "*" - - -class VoiceGender(StrEnum): - """Perceived voice gender, used to pick distinct voices per speaker.""" - - MALE = "male" - FEMALE = "female" - NEUTRAL = "neutral" - - -@dataclass(frozen=True, slots=True) -class CatalogVoice: - """One selectable voice. - - ``voice_id`` is the provider-prefixed, stable id stored on a speaker in the - brief (e.g. ``"kokoro:am_adam"``). ``native_ref`` is the untyped value the - TTS adapter passes to the provider — a string for most, a mapping for - Vertex — kept separate so renaming the catalog id never breaks synthesis. - """ - - voice_id: str - provider: TtsProvider - language: str - display_name: str - gender: VoiceGender - native_ref: VoiceRef - - def speaks(self, language: str) -> bool: - """Whether this voice can render ``language`` (primary subtag match).""" - if self.language == ANY_LANGUAGE: - return True - return _primary(self.language) == _primary(language) - - -def _primary(language: str) -> str: - return language.split("-", 1)[0].strip().lower() diff --git a/surfsense_backend/app/prompts/default_system_instructions.py b/surfsense_backend/app/prompts/default_system_instructions.py index b968fc1f0..fd0a8e186 100644 --- a/surfsense_backend/app/prompts/default_system_instructions.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -82,7 +82,7 @@ def build_configurable_system_prompt( *, model_name: str | None = None, ) -> str: - """Build a configurable SurfSense system prompt. + """Build a configurable SurfSense system prompt (NewLLMConfig path). See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. @@ -104,7 +104,7 @@ def build_configurable_system_prompt( def get_default_system_instructions() -> str: """Return the default ``<system_instruction>`` block (no tools / citations). - Useful for populating the UI when editing custom system instructions. + Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. The output reflects the current fragment tree, not a baked-in constant. """ resolved_today = datetime.now(UTC).date().isoformat() diff --git a/surfsense_backend/app/prompts/system_prompt_composer/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py index c639d4aa0..3849af313 100644 --- a/surfsense_backend/app/prompts/system_prompt_composer/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -348,7 +348,8 @@ def compose_system_prompt( mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject an explicit MCP routing block. custom_system_instructions: Free-form instructions that override - the default ``<system_instruction>`` block. + the default ``<system_instruction>`` block (legacy support + for ``NewLLMConfig.system_instructions``). use_default_system_instructions: When ``custom_system_instructions`` is empty/None, fall back to defaults (legacy semantics). citations_enabled: Include ``citations_on.md`` (true) or diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 5e5edec2e..47f7fe6b1 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -420,10 +420,7 @@ class ChucksHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -444,7 +441,7 @@ class ChucksHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(chunk_filter) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index d856e93cf..9ce86d404 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -357,10 +357,7 @@ class DocumentHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -372,7 +369,7 @@ class DocumentHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 8ce84d179..5cc029884 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -4,7 +4,6 @@ from app.automations.api import router as automations_router from app.file_storage.api import router as file_storage_router from app.gateway import require_gateway_enabled from app.notifications.api import router as notifications_router -from app.podcasts.api import router as podcasts_router from .agent_action_log_route import router as agent_action_log_router from .agent_flags_route import router as agent_flags_router @@ -24,10 +23,7 @@ from .dropbox_add_connector_route import router as dropbox_add_connector_router from .editor_routes import router as editor_router from .export_routes import router as export_router from .folders_routes import router as folders_router -from .gateway_webhook_routes import ( - config_router as gateway_config_router, - router as gateway_router, -) +from .gateway_webhook_routes import router as gateway_router from .gateway_whatsapp_baileys_routes import router as gateway_whatsapp_baileys_router from .gateway_whatsapp_webhook_routes import router as gateway_whatsapp_webhook_router from .google_calendar_add_connector_route import ( @@ -47,13 +43,14 @@ from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .mcp_oauth_route import router as mcp_oauth_router from .memory_routes import router as memory_router -from .model_connections_routes import router as model_connections_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router +from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router from .onedrive_add_connector_route import router as onedrive_add_connector_router +from .podcasts_routes import router as podcasts_router from .prompts_routes import router as prompts_router from .public_chat_routes import router as public_chat_router from .rbac_routes import router as rbac_router @@ -66,6 +63,7 @@ from .stripe_routes import router as stripe_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router +from .vision_llm_routes import router as vision_llm_router from .youtube_routes import router as youtube_router router = APIRouter() @@ -77,7 +75,6 @@ router.include_router(export_router) router.include_router(documents_router) router.include_router(folders_router) _gateway_enabled_dep = [Depends(require_gateway_enabled)] -router.include_router(gateway_config_router) router.include_router(gateway_router, dependencies=_gateway_enabled_dep) router.include_router( gateway_whatsapp_webhook_router, dependencies=_gateway_enabled_dep @@ -101,6 +98,7 @@ router.include_router( ) # Video presentation status and streaming router.include_router(reports_router) # Report CRUD and multi-format export router.include_router(image_generation_router) # Image generation via litellm +router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) @@ -118,7 +116,7 @@ router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) -router.include_router(model_connections_router) # Connection-centric model catalog +router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index f6f984c20..ad3277375 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -18,7 +18,6 @@ from app.etl_pipeline.file_classifier import ( PLAINTEXT_EXTENSIONS, ) from app.rate_limiter import limiter -from app.tasks.chat.streaming.errors.classifier import classify_stream_exception logger = logging.getLogger(__name__) @@ -99,6 +98,7 @@ class AnonQuotaResponse(BaseModel): class AnonModelResponse(BaseModel): id: int name: str + description: str | None = None provider: str model_name: str billing_tier: str = "free" @@ -131,7 +131,8 @@ async def list_anonymous_models(): AnonModelResponse( id=cfg.get("id", 0), name=cfg.get("name", ""), - provider=cfg.get("provider") or cfg.get("litellm_provider", ""), + description=cfg.get("description"), + provider=cfg.get("provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -159,7 +160,8 @@ async def get_anonymous_model(slug: str): return AnonModelResponse( id=cfg.get("id", 0), name=cfg.get("name", ""), - provider=cfg.get("provider") or cfg.get("litellm_provider", ""), + description=cfg.get("description"), + provider=cfg.get("provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -472,15 +474,7 @@ async def stream_anonymous_chat( except Exception as e: logger.exception("Anonymous chat stream error") await TokenQuotaService.anon_release(session_key, ip_key, request_id) - _, error_code, _, _, user_message, extra = classify_stream_exception( - e, - flow_label="chat", - ) - yield streaming_service.format_error( - user_message, - error_code=error_code, - extra=extra, - ) + yield streaming_service.format_error(f"Error during chat: {e!s}") yield streaming_service.format_done() finally: await TokenQuotaService.anon_release_stream_slot(client_ip) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 53f03a0ca..865068fba 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -1014,8 +1014,8 @@ async def get_document_by_chunk_id( .filter( Chunk.document_id == document.id, or_( - Chunk.position < chunk.position, - and_(Chunk.position == chunk.position, Chunk.id < chunk.id), + Chunk.created_at < chunk.created_at, + and_(Chunk.created_at == chunk.created_at, Chunk.id < chunk.id), ), ) ) @@ -1027,7 +1027,7 @@ async def get_document_by_chunk_id( windowed_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.created_at, Chunk.id) .offset(start) .limit(end - start) ) @@ -1137,7 +1137,7 @@ async def get_document_chunks_paginated( chunks_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.created_at, Chunk.id) .offset(offset) .limit(page_size) ) diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 8250fff98..166164c50 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -38,8 +38,7 @@ logger = logging.getLogger(__name__) router = APIRouter() -EDITOR_PLATE_MAX_BYTES = 1 * 1024 * 1024 -EDITOR_PLATE_MAX_LINES = 5000 +EDITOR_PLATE_MAX_BYTES = 5 * 1024 * 1024 @router.get("/search-spaces/{search_space_id}/documents/{document_id}/editor-content") @@ -84,22 +83,16 @@ async def get_editor_content( def _build_response(md: str) -> dict: size_bytes = len(md.encode("utf-8")) - line_count = md.count("\n") + 1 - too_large = ( - size_bytes > EDITOR_PLATE_MAX_BYTES or line_count > EDITOR_PLATE_MAX_LINES - ) - viewer_mode = "monaco" if too_large else "plate" + viewer_mode = "monaco" if size_bytes > EDITOR_PLATE_MAX_BYTES else "plate" return { "document_id": document.id, "title": document.title, "document_type": document.document_type.value, "source_markdown": md, "content_size_bytes": size_bytes, - "line_count": line_count, "chunk_count": chunk_count, "viewer_mode": viewer_mode, "editor_plate_max_bytes": EDITOR_PLATE_MAX_BYTES, - "editor_plate_max_lines": EDITOR_PLATE_MAX_LINES, "updated_at": document.updated_at.isoformat() if document.updated_at else None, @@ -126,7 +119,7 @@ async def get_editor_content( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() @@ -212,7 +205,7 @@ async def download_document_markdown( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: @@ -361,7 +354,7 @@ async def export_document( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: diff --git a/surfsense_backend/app/routes/gateway_webhook_routes.py b/surfsense_backend/app/routes/gateway_webhook_routes.py index 9b4af4b83..14f929567 100644 --- a/surfsense_backend/app/routes/gateway_webhook_routes.py +++ b/surfsense_backend/app/routes/gateway_webhook_routes.py @@ -56,7 +56,6 @@ from app.utils.oauth_security import OAuthStateManager, TokenEncryption from app.utils.rbac import check_search_space_access router = APIRouter(prefix="/gateway", tags=["gateway"]) -config_router = APIRouter(prefix="/gateway", tags=["gateway"]) logger = logging.getLogger(__name__) SLACK_AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize" @@ -968,20 +967,11 @@ async def list_platforms( ] -@config_router.get("/config") +@router.get("/config") async def get_gateway_config( user: User = Depends(current_active_user), ) -> dict[str, bool | str]: - if not config.GATEWAY_ENABLED: - return { - "enabled": False, - "telegram_enabled": False, - "whatsapp_intake_mode": "disabled", - "slack_enabled": False, - "discord_enabled": False, - } return { - "enabled": True, "telegram_enabled": _telegram_gateway_enabled(), "whatsapp_intake_mode": config.GATEWAY_WHATSAPP_INTAKE_MODE, "slack_enabled": _slack_gateway_enabled(), diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index cc3e51ed5..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -1,5 +1,7 @@ """ Image Generation routes: +- CRUD for ImageGenerationConfig (user-created image model configs) +- Global image gen configs endpoint (from YAML) - Image generation execution (calls litellm.aimage_generation()) - CRUD for ImageGeneration records (results) - Image serving endpoint (serves b64_json images from DB, protected by signed tokens) @@ -14,12 +16,11 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, - Model, + ImageGenerationConfig, Permission, SearchSpace, SearchSpaceMembership, @@ -27,14 +28,14 @@ from app.db import ( get_async_session, ) from app.schemas import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) from app.services.billable_calls import ( DEFAULT_IMAGE_RESERVE_MICROS, QuotaInsufficientError, @@ -42,10 +43,10 @@ from app.services.billable_calls import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm +from app.services.provider_api_base import resolve_api_base from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -53,16 +54,52 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) - -def _get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) +# Provider mapping for building litellm model strings. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} -def _get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image generation configuration by ID (negative IDs).""" + if config_id == IMAGE_GEN_AUTO_MODE_ID: + return { + "id": IMAGE_GEN_AUTO_MODE_ID, + "name": "Auto (Fastest)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + """Resolve the LiteLLM provider prefix used in model strings.""" + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + """Build a litellm model string from provider + model_name.""" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" async def _resolve_billing_for_image_gen( @@ -78,41 +115,34 @@ async def _resolve_billing_for_image_gen( config that will actually run, and so we don't open an ``ImageGeneration`` row for a request that's about to 402. - User-owned (positive ID) BYOK models are always free — they cost - the user nothing on our side. Auto mode resolves to one concrete - global or BYOK model before billing is calculated. + User-owned (positive ID) BYOK configs are always free — they cost + the user nothing on our side. Auto mode currently treats as free + because the underlying router can dispatch to either premium or + free YAML configs and we don't surface the resolved deployment up + here yet. Bringing Auto under premium billing would require + threading the chosen deployment back from ``ImageGenRouterService``. """ resolved_id = config_id if resolved_id is None: - resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID + resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space.id, - user_id=search_space.user_id, - capability="image_gen", - ) - if not candidates: - return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) - selected = choose_auto_model_candidate(candidates, search_space.id) - resolved_id = int(selected["id"]) + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) if resolved_id < 0: - global_model = _get_global_model(resolved_id) or {} - global_connection = _get_global_connection(global_model.get("connection_id", 0)) - billing_tier = str(global_model.get("billing_tier", "free")).lower() - if global_connection and global_model.get("model_id"): - base_model, _ = to_litellm(global_connection, global_model["model_id"]) - else: - base_model = "global_image_model" - catalog = global_model.get("catalog") or {} + cfg = _get_global_image_gen_config(resolved_id) or {} + billing_tier = str(cfg.get("billing_tier", "free")).lower() + base_model = _build_model_string( + cfg.get("provider", ""), + cfg.get("model_name", ""), + cfg.get("custom_provider"), + ) reserve_micros = int( - catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) return (billing_tier, base_model, reserve_micros) - # Positive ID = user-owned BYOK image-gen model — always free. + # Positive ID = user-owned BYOK image-gen config — always free. return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) @@ -125,14 +155,14 @@ async def _execute_image_generation( Call litellm.aimage_generation() with the appropriate config. Resolution order: - 1. Explicit image_gen_model_id on the request - 2. Search space's image_gen_model_id preference + 1. Explicit image_generation_config_id on the request + 2. Search space's image_generation_config_id preference 3. Falls back to Auto mode if available """ - config_id = image_gen.image_gen_model_id + config_id = image_gen.image_generation_config_id if config_id is None: - config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID - image_gen.image_gen_model_id = config_id + config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + image_gen.image_generation_config_id = config_id # Build kwargs gen_kwargs = {} @@ -148,30 +178,36 @@ async def _execute_image_generation( gen_kwargs["response_format"] = image_gen.response_format if is_image_gen_auto_mode(config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space.id, - user_id=search_space.user_id, - capability="image_gen", + if not ImageGenRouterService.is_initialized(): + raise ValueError( + "Auto mode requested but Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." + ) + response = await ImageGenRouterService.aimage_generation( + prompt=image_gen.prompt, model="auto", **gen_kwargs ) - if not candidates: - raise ValueError("No image-generation models are available for Auto mode") - config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) - image_gen.image_gen_model_id = config_id + elif config_id < 0: + # Global config from YAML + cfg = _get_global_image_gen_config(config_id) + if not cfg: + raise ValueError(f"Global image generation config {config_id} not found") - if config_id < 0: - global_model = _get_global_model(config_id) - if not global_model or not has_capability(global_model, "image_gen"): - raise ValueError(f"Global image generation model {config_id} not found") - global_connection = _get_global_connection(global_model["connection_id"]) - if not global_connection: - raise ValueError(f"Global connection for image model {config_id} not found") - - model_string, resolved_kwargs = to_litellm( - global_connection, - global_model["model_id"], + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) - gen_kwargs.update(resolved_kwargs) + model_string = f"{provider_prefix}/{cfg['model_name']}" + gen_kwargs["api_key"] = cfg.get("api_key") + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) # User model override if image_gen.model: @@ -181,28 +217,30 @@ async def _execute_image_generation( prompt=image_gen.prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = Model + Connection + # Positive ID = DB ImageGenerationConfig result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .filter(Model.id == config_id, Model.enabled.is_(True)) + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) ) - db_model = result.scalars().first() - if not db_model or not db_model.connection or not db_model.connection.enabled: - raise ValueError(f"Image generation model {config_id} not found") - conn = db_model.connection - if conn.search_space_id is not None and conn.search_space_id != search_space.id: - raise ValueError(f"Image generation model {config_id} not found") - if conn.user_id is not None and conn.user_id != search_space.user_id: - raise ValueError(f"Image generation model {config_id} not found") - if not has_capability(db_model, "image_gen"): - raise ValueError(f"Model {config_id} is not image-generation capable") + db_cfg = result.scalars().first() + if not db_cfg: + raise ValueError(f"Image generation config {config_id} not found") - model_string, resolved_kwargs = to_litellm( - db_model.connection, - db_model.model_id, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) - gen_kwargs.update(resolved_kwargs) + model_string = f"{provider_prefix}/{db_cfg.model_name}" + gen_kwargs["api_key"] = db_cfg.api_key + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) # User model override if image_gen.model: @@ -222,6 +260,266 @@ async def _execute_image_generation( image_gen.model = hidden["model"] +# ============================================================================= +# Global Image Generation Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-image-generation-configs", + response_model=list[GlobalImageGenConfigRead], +) +async def get_global_image_gen_configs( + user: User = Depends(current_active_user), +): + """Get all global image generation configs. API keys are hidden.""" + try: + global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS + safe_configs = [] + + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes across available image generation providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", + "is_premium": False, + } + ) + + for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", + "quota_reserve_micros": cfg.get("quota_reserve_micros"), + } + ) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global image generation configs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +# ============================================================================= +# ImageGenerationConfig CRUD +# ============================================================================= + + +@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) +async def create_image_gen_config( + config_data: ImageGenerationConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Create a new image generation config for a search space.""" + try: + await check_permission( + session, + user, + config_data.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to create image generation configs in this search space", + ) + + db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e + + +@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) +async def list_image_gen_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """List image generation configs for a search space.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to view image generation configs in this search space", + ) + + result = await session.execute( + select(ImageGenerationConfig) + .filter(ImageGenerationConfig.search_space_id == search_space_id) + .order_by(ImageGenerationConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list ImageGenerationConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +@router.get( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def get_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Get a specific image generation config by ID.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to view image generation configs in this search space", + ) + return db_config + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e + + +@router.put( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def update_image_gen_config( + config_id: int, + update_data: ImageGenerationConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Update an existing image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to update image generation configs in this search space", + ) + + for key, value in update_data.model_dump(exclude_unset=True).items(): + setattr(db_config, key, value) + + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e + + +@router.delete("/image-generation-configs/{config_id}", response_model=dict) +async def delete_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Delete an image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_DELETE.value, + "You don't have permission to delete image generation configs in this search space", + ) + + await session.delete(db_config) + await session.commit() + return { + "message": "Image generation config deleted successfully", + "id": config_id, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e + + # ============================================================================= # Image Generation Execution + Results CRUD # ============================================================================= @@ -270,7 +568,7 @@ async def create_image_generation( raise HTTPException(status_code=404, detail="Search space not found") billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( - session, data.image_gen_model_id, search_space + session, data.image_generation_config_id, search_space ) # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError @@ -296,7 +594,7 @@ async def create_image_generation( size=data.size, style=data.style, response_format=data.response_format, - image_gen_model_id=data.image_gen_model_id, + image_generation_config_id=data.image_generation_config_id, search_space_id=data.search_space_id, created_by_id=user.id, ) @@ -324,10 +622,11 @@ async def create_image_generation( detail={ "error_code": "premium_quota_exhausted", "usage_type": exc.usage_type, - "balance_micros": exc.balance_micros, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, "remaining_micros": exc.remaining_micros, "message": ( - "Out of credits for image generation. " + "Out of premium credits for image generation. " "Purchase additional credits or switch to a free model." ), }, diff --git a/surfsense_backend/app/routes/incentive_tasks_routes.py b/surfsense_backend/app/routes/incentive_tasks_routes.py index 1dae09a2d..496b07d06 100644 --- a/surfsense_backend/app/routes/incentive_tasks_routes.py +++ b/surfsense_backend/app/routes/incentive_tasks_routes.py @@ -1,6 +1,6 @@ """ Incentive Tasks API routes. -Allows users to complete tasks (like starring GitHub repo) to earn free credits. +Allows users to complete tasks (like starring GitHub repo) to earn free pages. Each task can only be completed once per user. """ @@ -42,21 +42,21 @@ async def get_incentive_tasks( # Build task list with completion status tasks = [] - total_credit_micros_earned = 0 + total_pages_earned = 0 for task_type, config in INCENTIVE_TASKS_CONFIG.items(): completed_task = completed_tasks.get(task_type) is_completed = completed_task is not None if is_completed: - total_credit_micros_earned += completed_task.credit_micros_awarded + total_pages_earned += completed_task.pages_awarded tasks.append( IncentiveTaskInfo( task_type=task_type, title=config["title"], description=config["description"], - credit_micros_reward=config["credit_micros_reward"], + pages_reward=config["pages_reward"], action_url=config["action_url"], completed=is_completed, completed_at=completed_task.completed_at if completed_task else None, @@ -65,7 +65,7 @@ async def get_incentive_tasks( return IncentiveTasksResponse( tasks=tasks, - total_credit_micros_earned=total_credit_micros_earned, + total_pages_earned=total_pages_earned, ) @@ -79,10 +79,10 @@ async def complete_task( session: AsyncSession = Depends(get_async_session), ) -> CompleteTaskResponse | TaskAlreadyCompletedResponse: """ - Mark an incentive task as completed and award credit to the user. + Mark an incentive task as completed and award pages to the user. Each task can only be completed once. If the task was already completed, - returns the existing completion information without awarding additional credit. + returns the existing completion information without awarding additional pages. """ # Validate task type exists in config task_config = INCENTIVE_TASKS_CONFIG.get(task_type) @@ -109,23 +109,25 @@ async def complete_task( ) # Create the task completion record - credit_micros_reward = task_config["credit_micros_reward"] + pages_reward = task_config["pages_reward"] new_task = UserIncentiveTask( user_id=user.id, task_type=task_type, - credit_micros_awarded=credit_micros_reward, + pages_awarded=pages_reward, ) session.add(new_task) - # Add the reward directly to the user's spendable wallet balance. - user.credit_micros_balance = user.credit_micros_balance + credit_micros_reward + # pages_used can exceed pages_limit when a document's final page count is + # determined after processing. Base the new limit on the higher of the two + # so the rewarded pages are fully usable above the current high-water mark. + user.pages_limit = max(user.pages_used, user.pages_limit) + pages_reward await session.commit() await session.refresh(user) return CompleteTaskResponse( success=True, - message=f"Task completed! You earned ${credit_micros_reward / 1_000_000:.2f} of credit.", - credit_micros_awarded=credit_micros_reward, - new_balance_micros=user.credit_micros_balance, + message=f"Task completed! You earned {pages_reward} pages.", + pages_awarded=pages_reward, + new_pages_limit=user.pages_limit, ) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py deleted file mode 100644 index 4d32a32af..000000000 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ /dev/null @@ -1,811 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from app.config import config -from app.db import ( - Connection, - ConnectionScope, - Model, - ModelSource, - NewChatThread, - Permission, - SearchSpace, - User, - get_async_session, -) -from app.schemas import ( - ConnectionCreate, - ConnectionRead, - ConnectionUpdate, - ModelCreate, - ModelPreviewRead, - ModelProviderRead, - ModelRead, - ModelRolesRead, - ModelRolesUpdate, - ModelsBulkUpdate, - ModelSelection, - ModelTestPreview, - ModelUpdate, - VerifyConnectionResponse, -) -from app.services.model_capabilities import has_capability -from app.services.model_connection_service import ( - ModelDiscoveryError, - derive_capabilities, - discover_models, - test_model, - verify_connection, -) -from app.services.provider_registry import REGISTRY -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -def _model_read(model: Model | dict) -> ModelRead: - return ModelRead.model_validate(model) - - -def _preview_model_read(item: dict) -> ModelPreviewRead: - return ModelPreviewRead( - model_id=item["model_id"], - display_name=item.get("display_name"), - source=item.get("source", ModelSource.DISCOVERED), - supports_chat=item.get("supports_chat"), - max_input_tokens=item.get("max_input_tokens"), - supports_image_input=item.get("supports_image_input"), - supports_tools=item.get("supports_tools"), - supports_image_generation=item.get("supports_image_generation"), - enabled=item.get("enabled", False), - metadata=item.get("metadata") or item.get("catalog") or {}, - ) - - -def _connection_read( - conn: Connection | dict, models: list[Model | dict] | None = None -) -> ConnectionRead: - if isinstance(conn, dict): - payload = { - **conn, - "has_api_key": bool(conn.get("api_key")), - "api_key": None, - "models": [_model_read(model) for model in (models or [])], - } - payload.pop("api_key", None) - return ConnectionRead.model_validate(payload) - - return ConnectionRead( - id=conn.id, - provider=conn.provider, - base_url=conn.base_url, - api_key=conn.api_key, - extra=conn.extra or {}, - scope=conn.scope, - search_space_id=conn.search_space_id, - user_id=conn.user_id, - enabled=conn.enabled, - has_api_key=bool(conn.api_key), - models=[_model_read(model) for model in (models or [])], - created_at=conn.created_at, - ) - - -def _apply_model_facts(model: Model, facts: dict) -> None: - model.supports_chat = facts.get("supports_chat") - model.max_input_tokens = facts.get("max_input_tokens") - model.supports_image_input = facts.get("supports_image_input") - model.supports_tools = facts.get("supports_tools") - model.supports_image_generation = facts.get("supports_image_generation") - - -def _complete_selection_facts(conn: Connection, selection: ModelSelection) -> dict: - facts = selection.model_dump() - derived = derive_capabilities(conn, selection.model_id.strip(), selection.metadata) - for key, value in derived.items(): - if facts.get(key) is None: - facts[key] = value - return facts - - -def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model: - source = ( - selection.source - if isinstance(selection.source, ModelSource) - else ModelSource(selection.source) - ) - model = Model( - connection_id=conn.id, - model_id=selection.model_id.strip(), - display_name=selection.display_name, - source=source, - capabilities_override={}, - enabled=selection.enabled, - catalog=selection.metadata, - ) - _apply_model_facts(model, _complete_selection_facts(conn, selection)) - return model - - -def _default_model_for(models: list[Model], capability: str) -> int | None: - for model in models: - if model.enabled and has_capability(model, capability): - return model.id - return None - - -async def _load_role_model( - session: AsyncSession, - search_space_id: int, - model_id: int, -) -> Model | dict | None: - if model_id < 0: - return next( - (model for model in config.GLOBAL_MODELS if model.get("id") == model_id), - None, - ) - - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if model is None or model.connection.search_space_id != search_space_id: - return None - return model - - -def _role_model_enabled(model: Model | dict) -> bool: - if isinstance(model, dict): - return bool(model.get("enabled", True)) - return bool(model.enabled and model.connection.enabled) - - -async def _validate_role_model_id( - session: AsyncSession, - *, - search_space_id: int, - model_id: int | None, - capability: str, -) -> int: - if model_id is None or model_id == 0: - return 0 - - model = await _load_role_model(session, search_space_id, model_id) - if model and _role_model_enabled(model) and has_capability(model, capability): - return model_id - - raise HTTPException( - status_code=400, - detail=f"Selected model is not available for {capability}", - ) - - -async def _resolve_role_model_id( - session: AsyncSession, - *, - search_space_id: int, - model_id: int | None, - capability: str, -) -> int: - try: - return await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=model_id, - capability=capability, - ) - except HTTPException: - return 0 - - -async def _clear_invalid_roles( - session: AsyncSession, search_space_id: int -) -> SearchSpace: - search_space = await _get_search_space(session, search_space_id) - search_space.chat_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.chat_model_id, - capability="chat", - ) - search_space.vision_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.vision_model_id, - capability="vision", - ) - search_space.image_gen_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.image_gen_model_id, - capability="image_gen", - ) - return search_space - - -async def _default_unset_roles( - session: AsyncSession, - conn: Connection, - models: list[Model], -) -> None: - if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None: - return - search_space = await _get_search_space(session, conn.search_space_id) - if search_space.chat_model_id is None: - search_space.chat_model_id = _default_model_for(models, "chat") - if search_space.vision_model_id is None: - vision_default = None - if search_space.chat_model_id: - chat_model = next( - (m for m in models if m.id == search_space.chat_model_id), None - ) - if chat_model and has_capability(chat_model, "vision"): - vision_default = chat_model.id - search_space.vision_model_id = vision_default or _default_model_for( - models, "vision" - ) - if search_space.image_gen_model_id is None: - search_space.image_gen_model_id = _default_model_for(models, "image_gen") - - -@router.get("/model-providers", response_model=list[ModelProviderRead]) -async def list_model_providers(user: User = Depends(current_active_user)): - del user - local_only = {"ollama_chat", "lm_studio"} - return [ - ModelProviderRead( - provider=provider, - transport=spec.transport.value, - discovery=spec.discovery, - default_base_url=spec.default_base_url, - base_url_required=spec.base_url_required, - auth_style=spec.auth_style, - local_only=provider in local_only, - ) - for provider, spec in sorted(REGISTRY.items()) - ] - - -async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: - result = await session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - return search_space - - -async def _load_connection(session: AsyncSession, connection_id: int) -> Connection: - result = await session.execute( - select(Connection) - .options(selectinload(Connection.models)) - .where(Connection.id == connection_id) - ) - conn = result.scalars().first() - if not conn: - raise HTTPException(status_code=404, detail="Connection not found") - return conn - - -async def _assert_connection_access( - session: AsyncSession, - user: User, - conn: Connection, - permission: str = Permission.LLM_CONFIGS_CREATE.value, -) -> None: - if conn.search_space_id: - await check_permission( - session, - user, - conn.search_space_id, - permission, - "You don't have permission to manage model connections in this search space", - ) - return - if conn.user_id != user.id: - raise HTTPException( - status_code=403, detail="Connection does not belong to user" - ) - - -@router.get("/global-llm-config-status") -async def global_llm_config_status(user: User = Depends(current_active_user)): - del user - return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS} - - -@router.get("/global-model-connections", response_model=list[ConnectionRead]) -async def list_global_connections(user: User = Depends(current_active_user)): - del user - models_by_connection: dict[int, list[dict]] = {} - for model in config.GLOBAL_MODELS: - models_by_connection.setdefault(model["connection_id"], []).append(model) - return [ - _connection_read(conn, models_by_connection.get(conn["id"], [])) - for conn in config.GLOBAL_CONNECTIONS - ] - - -@router.get("/model-connections", response_model=list[ConnectionRead]) -async def list_connections( - search_space_id: int | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - stmt = select(Connection).options(selectinload(Connection.models)) - if search_space_id is not None: - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to view model connections in this search space", - ) - stmt = stmt.where(Connection.search_space_id == search_space_id) - else: - stmt = stmt.where(Connection.user_id == user.id) - result = await session.execute(stmt.order_by(Connection.id)) - return [ - _connection_read(conn, list(conn.models)) for conn in result.scalars().all() - ] - - -@router.post("/model-connections", response_model=ConnectionRead) -async def create_connection( - data: ConnectionCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.GLOBAL: - raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") - if data.scope == ConnectionScope.SEARCH_SPACE: - if data.search_space_id is None: - raise HTTPException(status_code=400, detail="search_space_id is required") - await check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - payload = data.model_dump(exclude={"search_space_id", "models"}) - - conn = Connection( - **payload, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - session.add(conn) - await session.flush() - - seen_model_ids: set[str] = set() - for selection in data.models: - model_id = selection.model_id.strip() - if not model_id or model_id in seen_model_ids: - continue - seen_model_ids.add(model_id) - session.add(_selection_to_model(conn, selection)) - - await session.commit() - conn = await _load_connection(session, conn.id) - await _default_unset_roles(session, conn, list(conn.models)) - await session.commit() - conn = await _load_connection(session, conn.id) - return _connection_read(conn, list(conn.models)) - - -@router.post( - "/model-connections/discover-preview", response_model=list[ModelPreviewRead] -) -async def preview_connection_models( - data: ConnectionCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: - await check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - - draft = Connection( - provider=data.provider, - base_url=data.base_url, - api_key=data.api_key, - extra=data.extra or {}, - scope=data.scope, - enabled=data.enabled, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - try: - discovered = await discover_models(draft) - except ModelDiscoveryError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - return [_preview_model_read(item) for item in discovered] - - -@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse) -async def test_preview_connection_model( - data: ModelTestPreview, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: - await check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - - model_id = data.model_id.strip() - if not model_id: - raise HTTPException(status_code=400, detail="model_id is required") - - draft = Connection( - provider=data.provider, - base_url=data.base_url, - api_key=data.api_key, - extra=data.extra or {}, - scope=data.scope, - enabled=data.enabled, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - model = Model( - connection_id=0, - model_id=model_id, - source=ModelSource.MANUAL, - enabled=True, - capabilities_override={}, - catalog={}, - ) - result = await test_model(draft, model) - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) -async def update_connection( - connection_id: int, - data: ConnectionUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = conn.search_space_id - for key, value in data.model_dump(exclude_unset=True).items(): - setattr(conn, key, value) - await session.commit() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - conn = await _load_connection(session, connection_id) - return _connection_read(conn, list(conn.models)) - - -@router.delete("/model-connections/{connection_id}") -async def delete_connection( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_DELETE.value - ) - search_space_id = conn.search_space_id - await session.delete(conn) - await session.commit() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - return {"status": "deleted"} - - -@router.post( - "/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse -) -async def verify_model_connection( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value - ) - result = await verify_connection(conn) - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.post( - "/model-connections/{connection_id}/discover", response_model=list[ModelRead] -) -async def discover_connection_models( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value - ) - try: - discovered = await discover_models(conn) - except ModelDiscoveryError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - by_model_id = {model.model_id: model for model in conn.models} - for item in discovered: - db_model = by_model_id.get(item["model_id"]) - if db_model is None: - db_model = Model( - connection_id=conn.id, - model_id=item["model_id"], - display_name=item.get("display_name"), - source=item["source"], - capabilities_override={}, - enabled=False, - catalog=item.get("metadata") or {}, - ) - _apply_model_facts(db_model, item) - session.add(db_model) - else: - db_model.display_name = item.get("display_name") or db_model.display_name - _apply_model_facts(db_model, item) - db_model.catalog = item.get("metadata") or db_model.catalog - await session.commit() - conn = await _load_connection(session, connection_id) - await _default_unset_roles(session, conn, list(conn.models)) - if conn.search_space_id is not None: - await _clear_invalid_roles(session, conn.search_space_id) - await session.commit() - conn = await _load_connection(session, connection_id) - return [_model_read(model) for model in conn.models] - - -@router.post("/model-connections/{connection_id}/models", response_model=ModelRead) -async def add_manual_model( - connection_id: int, - data: ModelCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - - model_id = data.model_id.strip() - if not model_id: - raise HTTPException(status_code=400, detail="model_id is required") - if any(existing.model_id == model_id for existing in conn.models): - raise HTTPException( - status_code=400, detail="Model already exists on this connection" - ) - - capabilities = derive_capabilities(conn, model_id) - model = Model( - connection_id=conn.id, - model_id=model_id, - display_name=data.display_name or None, - source=ModelSource.MANUAL, - capabilities_override={}, - enabled=True, - catalog={}, - ) - _apply_model_facts(model, capabilities) - session.add(model) - await session.commit() - await session.refresh(model) - conn = await _load_connection(session, connection_id) - await _default_unset_roles(session, conn, list(conn.models)) - if conn.search_space_id is not None: - await _clear_invalid_roles(session, conn.search_space_id) - await session.commit() - await session.refresh(model) - return _model_read(model) - - -@router.patch( - "/model-connections/{connection_id}/models", response_model=list[ModelRead] -) -async def bulk_update_models( - connection_id: int, - data: ModelsBulkUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = conn.search_space_id - - model_ids = set(data.model_ids) - await session.execute( - update(Model) - .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) - .values(enabled=data.enabled) - ) - await session.commit() - session.expire_all() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - session.expire_all() - - result = await session.execute( - select(Model) - .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) - .order_by(Model.id) - ) - return [_model_read(model) for model in result.scalars().all()] - - -@router.put("/models/{model_id}", response_model=ModelRead) -async def update_model( - model_id: int, - data: ModelUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = model.connection.search_space_id - update = data.model_dump(exclude_unset=True) - for key, value in update.items(): - setattr(model, key, value) - await session.commit() - await session.refresh(model) - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - await session.refresh(model) - return _model_read(model) - - -@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse) -async def test_connection_model( - model_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value - ) - result = await test_model(model.connection, model) - await session.commit() - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.get( - "/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead -) -async def get_model_roles( - search_space_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to view model roles in this search space", - ) - search_space = await _clear_invalid_roles(session, search_space_id) - await session.commit() - await session.refresh(search_space) - return ModelRolesRead( - chat_model_id=search_space.chat_model_id, - vision_model_id=search_space.vision_model_id, - image_gen_model_id=search_space.image_gen_model_id, - ) - - -@router.put( - "/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead -) -async def update_model_roles( - search_space_id: int, - data: ModelRolesUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update model roles in this search space", - ) - search_space = await _get_search_space(session, search_space_id) - updates = data.model_dump(exclude_unset=True) - if "chat_model_id" in updates: - previous_chat_model_id = search_space.chat_model_id - next_chat_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["chat_model_id"], - capability="chat", - ) - search_space.chat_model_id = next_chat_model_id - if next_chat_model_id != previous_chat_model_id: - await session.execute( - update(NewChatThread) - .where(NewChatThread.search_space_id == search_space_id) - .values(pinned_llm_config_id=None) - ) - logger.info( - "Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)", - search_space_id, - previous_chat_model_id, - next_chat_model_id, - ) - if "vision_model_id" in updates: - search_space.vision_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["vision_model_id"], - capability="vision", - ) - if "image_gen_model_id" in updates: - search_space.image_gen_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["image_gen_model_id"], - capability="image_gen", - ) - await session.commit() - await session.refresh(search_space) - return ModelRolesRead( - chat_model_id=search_space.chat_model_id, - vision_model_id=search_space.vision_model_id, - image_gen_model_id=search_space.image_gen_model_id, - ) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5bc2571e..0e4e557be 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1741,11 +1741,12 @@ async def handle_new_chat( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - # Use the converged model-connections role for chat operations. - # Positive IDs load Model + Connection rows; negative IDs load - # virtual GLOBAL models; 0 means Auto. + # Use agent_llm_id from search space for chat operations + # Positive IDs load from NewLLMConfig database table + # Negative IDs load from YAML global configs + # Falls back to -1 (first global config) if not configured llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2227,7 +2228,7 @@ async def regenerate_response( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2392,7 +2393,7 @@ async def resume_chat( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) decisions = [d.model_dump() for d in request.decisions] diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py new file mode 100644 index 000000000..84d66bb13 --- /dev/null +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -0,0 +1,480 @@ +""" +API routes for NewLLMConfig CRUD operations. + +NewLLMConfig combines model settings with prompt configuration: +- LLM provider, model, API key, etc. +- Configurable system instructions +- Citation toggle +""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + NewLLMConfig, + Permission, + User, + get_async_session, +) +from app.prompts.default_system_instructions import get_default_system_instructions +from app.schemas import ( + DefaultSystemInstructionsResponse, + GlobalNewLLMConfigRead, + NewLLMConfigCreate, + NewLLMConfigRead, + NewLLMConfigUpdate, +) +from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. + """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + +# ============================================================================= +# Global Configs Routes +# ============================================================================= + + +@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead]) +async def get_global_new_llm_configs( + user: User = Depends(current_active_user), +): + """ + Get all available global NewLLMConfig configurations. + These are pre-configured by the system administrator and available to all users. + API keys are not exposed through this endpoint. + + Includes: + - Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing + - Global configs (negative IDs): Individual pre-configured LLM providers + """ + try: + global_configs = config.GLOBAL_LLM_CONFIGS + safe_configs = [] + + # Only include Auto mode if there are actual global configs to route to + # Auto mode requires at least one global config with valid API key + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "litellm_params": {}, + "system_instructions": "", + "use_default_system_instructions": True, + "citations_enabled": True, + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + "is_premium": False, + "anonymous_enabled": False, + "seo_enabled": False, + "seo_slug": None, + "seo_title": None, + "seo_description": None, + "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, + } + ) + + # Add individual global configs + for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. + if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + + safe_config = { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "litellm_params": cfg.get("litellm_params", {}), + # New prompt configuration fields + "system_instructions": cfg.get("system_instructions", ""), + "use_default_system_instructions": cfg.get( + "use_default_system_instructions", True + ), + "citations_enabled": cfg.get("citations_enabled", True), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "is_premium": cfg.get("billing_tier", "free") == "premium", + "anonymous_enabled": cfg.get("anonymous_enabled", False), + "seo_enabled": cfg.get("seo_enabled", False), + "seo_slug": cfg.get("seo_slug"), + "seo_title": cfg.get("seo_title"), + "seo_description": cfg.get("seo_description"), + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, + } + safe_configs.append(safe_config) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global NewLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch global configurations: {e!s}" + ) from e + + +# ============================================================================= +# CRUD Routes +# ============================================================================= + + +@router.post("/new-llm-configs", response_model=NewLLMConfigRead) +async def create_new_llm_config( + config_data: NewLLMConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new NewLLMConfig for a search space. + Requires LLM_CONFIGS_CREATE permission. + """ + try: + # Verify user has permission + await check_permission( + session, + user, + config_data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create LLM configurations in this search space", + ) + + # Validate the LLM configuration by making a test API call + is_valid, error_message = await validate_llm_config( + provider=config_data.provider.value, + model_name=config_data.model_name, + api_key=config_data.api_key, + api_base=config_data.api_base, + custom_provider=config_data.custom_provider, + litellm_params=config_data.litellm_params, + ) + + if not is_valid: + raise HTTPException( + status_code=400, + detail=f"Invalid LLM configuration: {error_message}", + ) + + # Create the config with user association + db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + + return _serialize_byok_config(db_config) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create configuration: {e!s}" + ) from e + + +@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead]) +async def list_new_llm_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get all NewLLMConfigs for a search space. + Requires LLM_CONFIGS_READ permission. + """ + try: + # Verify user has permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) + + result = await session.execute( + select(NewLLMConfig) + .filter(NewLLMConfig.search_space_id == search_space_id) + .order_by(NewLLMConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list NewLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configurations: {e!s}" + ) from e + + +@router.get( + "/new-llm-configs/default-system-instructions", + response_model=DefaultSystemInstructionsResponse, +) +async def get_default_system_instructions_endpoint( + user: User = Depends(current_active_user), +): + """ + Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template. + Useful for pre-populating the UI when creating a new configuration. + """ + return DefaultSystemInstructionsResponse( + default_system_instructions=get_default_system_instructions() + ) + + +@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) +async def get_new_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific NewLLMConfig by ID. + Requires LLM_CONFIGS_READ permission. + """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) + + return _serialize_byok_config(config) + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configuration: {e!s}" + ) from e + + +@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) +async def update_new_llm_config( + config_id: int, + update_data: NewLLMConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an existing NewLLMConfig. + Requires LLM_CONFIGS_UPDATE permission. + """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update LLM configurations in this search space", + ) + + update_dict = update_data.model_dump(exclude_unset=True) + + # If updating LLM settings, validate them + if any( + key in update_dict + for key in [ + "provider", + "model_name", + "api_key", + "api_base", + "custom_provider", + "litellm_params", + ] + ): + # Build the validation config from existing + updates + validation_config = { + "provider": update_dict.get("provider", config.provider).value + if hasattr(update_dict.get("provider", config.provider), "value") + else update_dict.get("provider", config.provider.value), + "model_name": update_dict.get("model_name", config.model_name), + "api_key": update_dict.get("api_key", config.api_key), + "api_base": update_dict.get("api_base", config.api_base), + "custom_provider": update_dict.get( + "custom_provider", config.custom_provider + ), + "litellm_params": update_dict.get( + "litellm_params", config.litellm_params + ), + } + + is_valid, error_message = await validate_llm_config( + provider=validation_config["provider"], + model_name=validation_config["model_name"], + api_key=validation_config["api_key"], + api_base=validation_config["api_base"], + custom_provider=validation_config["custom_provider"], + litellm_params=validation_config["litellm_params"], + ) + + if not is_valid: + raise HTTPException( + status_code=400, + detail=f"Invalid LLM configuration: {error_message}", + ) + + # Apply updates + for key, value in update_dict.items(): + setattr(config, key, value) + + await session.commit() + await session.refresh(config) + + return _serialize_byok_config(config) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update configuration: {e!s}" + ) from e + + +@router.delete("/new-llm-configs/{config_id}", response_model=dict) +async def delete_new_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a NewLLMConfig. + Requires LLM_CONFIGS_DELETE permission. + """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_DELETE.value, + "You don't have permission to delete LLM configurations in this search space", + ) + + await session.delete(config) + await session.commit() + + return {"message": "Configuration deleted successfully", "id": config_id} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete configuration: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py new file mode 100644 index 000000000..f991f698f --- /dev/null +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -0,0 +1,211 @@ +""" +Podcast routes for CRUD operations and audio streaming. + +These routes support the podcast generation feature in new-chat. +Frontend polls GET /podcasts/{podcast_id} to check status field. +""" + +import os +from pathlib import Path + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + Permission, + Podcast, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) +from app.schemas import PodcastRead +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() + + +@router.get("/podcasts", response_model=list[PodcastRead]) +async def read_podcasts( + skip: int = 0, + limit: int = 100, + search_space_id: int | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List podcasts the user has access to. + Requires PODCASTS_READ permission for the search space(s). + """ + if skip < 0 or limit < 1: + raise HTTPException(status_code=400, detail="Invalid pagination parameters") + try: + if search_space_id is not None: + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + result = await session.execute( + select(Podcast) + .filter(Podcast.search_space_id == search_space_id) + .offset(skip) + .limit(limit) + ) + else: + # Get podcasts from all search spaces user has membership in + result = await session.execute( + select(Podcast) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + except HTTPException: + raise + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcasts" + ) from None + + +@router.get("/podcasts/{podcast_id}", response_model=PodcastRead) +async def read_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific podcast by ID. + + Requires authentication with PODCASTS_READ permission. + For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise HTTPException( + status_code=404, + detail="Podcast not found", + ) + + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + + return PodcastRead.from_orm_with_entries(podcast) + except HTTPException as he: + raise he + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcast" + ) from None + + +@router.delete("/podcasts/{podcast_id}", response_model=dict) +async def delete_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a podcast. + Requires PODCASTS_DELETE permission for the search space. + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + db_podcast = result.scalars().first() + + if not db_podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_podcast.search_space_id, + Permission.PODCASTS_DELETE.value, + "You don't have permission to delete podcasts in this search space", + ) + + await session.delete(db_podcast) + await session.commit() + return {"message": "Podcast deleted successfully"} + except HTTPException as he: + raise he + except SQLAlchemyError: + await session.rollback() + raise HTTPException( + status_code=500, detail="Database error occurred while deleting podcast" + ) from None + + +@router.get("/podcasts/{podcast_id}/stream") +@router.get("/podcasts/{podcast_id}/audio") +async def stream_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Stream a podcast audio file. + + Requires authentication with PODCASTS_READ permission. + For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream + + Note: Both /stream and /audio endpoints are supported for compatibility. + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to access podcasts in this search space", + ) + + file_path = podcast.file_location + + if not file_path or not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="Podcast audio file not found") + + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + + return StreamingResponse( + iterfile(), + media_type="audio/mpeg", + headers={ + "Accept-Ranges": "bytes", + "Content-Disposition": f"inline; filename={Path(file_path).name}", + }, + ) + + except HTTPException as he: + raise he + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error streaming podcast: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index 516e976e6..3181e117c 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -99,23 +99,6 @@ async def stream_public_podcast( if not podcast_info: raise HTTPException(status_code=404, detail="Podcast not found") - storage_key = podcast_info.get("storage_key") - if storage_key: - from app.file_storage.factory import get_storage_backend - - backend = get_storage_backend() - # Verify first so a missing object is a 404, not a mid-stream crash. - if not await backend.exists(storage_key): - raise HTTPException( - status_code=404, detail="Podcast audio is no longer available" - ) - return StreamingResponse( - backend.open_stream(storage_key), - media_type="audio/mpeg", - headers={"Accept-Ranges": "bytes"}, - ) - - # Legacy fallback for snapshots taken before the storage migration. file_path = podcast_info.get("file_path") if not file_path or not os.path.isfile(file_path): diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 512b52ae4..dc26b4c02 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -745,23 +745,11 @@ async def index_connector_content( if not connector: raise HTTPException(status_code=404, detail="Connector not found") - # Ensure the connector actually belongs to the requested search space. - # Without this, the permission check below would authorize against the - # caller-supplied search_space_id (their own space) while the connector - # lives in another user's space, allowing cross-tenant indexing of a - # foreign connector (and use of its stored credentials). Returning 404 - # (rather than 403) on a mismatch also avoids disclosing the existence of - # connectors in other search spaces. - if connector.search_space_id != search_space_id: - raise HTTPException(status_code=404, detail="Connector not found") - - # Check if user has permission to update connectors (indexing is an update - # operation). Authorize against the connector's OWN search space — matching - # the read/update/delete handlers — not the client-supplied query param. + # Check if user has permission to update connectors (indexing is an update operation) await check_permission( session, user, - connector.search_space_id, + search_space_id, Permission.CONNECTORS_UPDATE.value, "You don't have permission to index content in this search space", ) diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 592a9dd0e..898077b7a 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,20 +1,27 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.config import config from app.db import ( + ImageGenerationConfig, + NewChatThread, + NewLLMConfig, Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, User, + VisionLLMConfig, get_async_session, get_default_roles_config, ) from app.schemas import ( + LLMPreferencesRead, + LLMPreferencesUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, @@ -370,6 +377,357 @@ async def delete_search_space( ) from e +# ============================================================================= +# LLM Preferences Routes +# ============================================================================= + + +async def _get_llm_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + """ + Get an LLM config by ID as a dictionary. Returns database config for positive IDs, + global config for negative IDs, Auto mode config for ID 0, or None if ID is None. + """ + if config_id is None: + return None + + # Auto mode (ID 0) - uses LiteLLM Router for load balancing + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "litellm_params": {}, + "system_instructions": "", + "use_default_system_instructions": True, + "citations_enabled": True, + "is_global": True, + "is_auto_mode": True, + } + + if config_id < 0: + # Global config - find from YAML + global_configs = config.GLOBAL_LLM_CONFIGS + for cfg in global_configs: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base"), + "litellm_params": cfg.get("litellm_params", {}), + "system_instructions": cfg.get("system_instructions", ""), + "use_default_system_instructions": cfg.get( + "use_default_system_instructions", True + ), + "citations_enabled": cfg.get("citations_enabled", True), + "is_global": True, + } + return None + else: + # Database config - convert to dict + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_key": db_config.api_key, + "api_base": db_config.api_base, + "litellm_params": db_config.litellm_params or {}, + "system_instructions": db_config.system_instructions or "", + "use_default_system_instructions": db_config.use_default_system_instructions, + "citations_enabled": db_config.citations_enabled, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +async def _get_image_gen_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + """ + Get an image generation config by ID as a dictionary. + Returns Auto mode for ID 0, global config for negative IDs, + DB ImageGenerationConfig for positive IDs, or None. + """ + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available image generation providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + } + + if config_id < 0: + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + } + return None + + # Positive ID: query ImageGenerationConfig table + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + "litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +async def _get_vision_llm_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available vision LLM providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + } + + if config_id < 0: + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + } + return None + + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + "litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +@router.get( + "/search-spaces/{search_space_id}/llm-preferences", + response_model=LLMPreferencesRead, +) +async def get_llm_preferences( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get LLM preferences (role assignments) for a search space. + Requires LLM_CONFIGS_READ permission. + """ + try: + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM preferences", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + # Get full config objects for each role + agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) + vision_llm_config = await _get_vision_llm_config_by_id( + session, search_space.vision_llm_config_id + ) + + return LLMPreferencesRead( + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, + agent_llm=agent_llm, + image_generation_config=image_generation_config, + vision_llm_config=vision_llm_config, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get LLM preferences") + raise HTTPException( + status_code=500, detail=f"Failed to get LLM preferences: {e!s}" + ) from e + + +@router.put( + "/search-spaces/{search_space_id}/llm-preferences", + response_model=LLMPreferencesRead, +) +async def update_llm_preferences( + search_space_id: int, + preferences: LLMPreferencesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update LLM preferences (role assignments) for a search space. + Requires LLM_CONFIGS_UPDATE permission. + """ + try: + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update LLM preferences", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + # Update preferences + update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id + for key, value in update_data.items(): + setattr(search_space, key, value) + + agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + + await session.commit() + await session.refresh(search_space) + + # Get full config objects for response + agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) + vision_llm_config = await _get_vision_llm_config_by_id( + session, search_space.vision_llm_config_id + ) + + return LLMPreferencesRead( + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, + agent_llm=agent_llm, + image_generation_config=image_generation_config, + vision_llm_config=vision_llm_config, + ) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update LLM preferences") + raise HTTPException( + status_code=500, detail=f"Failed to update LLM preferences: {e!s}" + ) from e + + @router.get("/searchspaces/{search_space_id}/snapshots") async def list_search_space_snapshots( search_space_id: int, diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index 23dce58cd..fc5fded84 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -1,10 +1,4 @@ -"""Stripe routes for the unified credit wallet. - -Buying credit packs ($1 == 1_000_000 micro-USD by default) tops up -``user.credit_micros_balance``. The same balance is debited for ETL page -processing and premium model calls. Legacy page-pack buying has been removed; -``page_purchases`` history is still readable via ``GET /stripe/purchases``. -""" +"""Stripe routes for pay-as-you-go page purchases.""" from __future__ import annotations @@ -20,24 +14,24 @@ from stripe import SignatureVerificationError, StripeClient, StripeError from app.config import config from app.db import ( - CreditPurchase, - CreditPurchaseStatus, PagePurchase, + PagePurchaseStatus, + PremiumTokenPurchase, + PremiumTokenPurchaseStatus, User, get_async_session, ) from app.schemas.stripe import ( - AutoReloadSettingsResponse, - CreateAutoReloadSetupSessionRequest, - CreateAutoReloadSetupSessionResponse, - CreateCreditCheckoutSessionRequest, - CreateCreditCheckoutSessionResponse, - CreditPurchaseHistoryResponse, - CreditStripeStatusResponse, + CreateCheckoutSessionRequest, + CreateCheckoutSessionResponse, + CreateTokenCheckoutSessionRequest, + CreateTokenCheckoutSessionResponse, FinalizeCheckoutResponse, PagePurchaseHistoryResponse, + StripeStatusResponse, StripeWebhookResponse, - UpdateAutoReloadSettingsRequest, + TokenPurchaseHistoryResponse, + TokenStripeStatusResponse, ) from app.users import current_active_user @@ -56,11 +50,11 @@ def get_stripe_client() -> StripeClient: return StripeClient(config.STRIPE_SECRET_KEY) -def _ensure_credit_buying_enabled() -> None: - if not config.STRIPE_CREDIT_BUYING_ENABLED: +def _ensure_page_buying_enabled() -> None: + if not config.STRIPE_PAGE_BUYING_ENABLED: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Credit purchases are temporarily unavailable.", + detail="Page purchases are temporarily unavailable.", ) @@ -85,62 +79,13 @@ def _get_checkout_urls(search_space_id: int) -> tuple[str, str]: return success_url, cancel_url -def _get_required_credit_price_id() -> str: - if not config.STRIPE_CREDIT_PRICE_ID: +def _get_required_stripe_price_id() -> str: + if not config.STRIPE_PRICE_ID: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="STRIPE_CREDIT_PRICE_ID is not configured.", + detail="STRIPE_PRICE_ID is not configured.", ) - return config.STRIPE_CREDIT_PRICE_ID - - -def _ensure_auto_reload_enabled() -> None: - if not config.AUTO_RELOAD_ENABLED: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Auto-reload is not available.", - ) - - -async def _get_or_create_stripe_customer( - stripe_client: StripeClient, db_session: AsyncSession, user: User -) -> str: - """Return the user's Stripe Customer id, creating + persisting one if needed. - - A Customer object is required to save and later reuse a card off-session - (Stripe: save-and-reuse). New checkouts attach to this customer so the same - saved card powers both manual top-ups and auto-reload. - """ - if user.stripe_customer_id: - return user.stripe_customer_id - - customer = stripe_client.v1.customers.create( - params={ - "email": user.email, - "metadata": {"user_id": str(user.id)}, - } - ) - customer_id = str(customer.id) - - # Persist on the live row with a lock to avoid two concurrent checkouts - # creating duplicate customers. - locked = ( - ( - await db_session.execute( - select(User).where(User.id == user.id).with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if locked is not None: - if locked.stripe_customer_id: - # Another request won the race; reuse theirs. - customer_id = locked.stripe_customer_id - else: - locked.stripe_customer_id = customer_id - await db_session.commit() - return customer_id + return config.STRIPE_PRICE_ID def _normalize_optional_string(value: Any) -> str | None: @@ -165,9 +110,14 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: if metadata is None: return {} + # 1. Plain dict (older SDKs that subclassed dict, JSON-decoded events + # in tests, etc.). if isinstance(metadata, dict): return {str(k): str(v) for k, v in metadata.items()} + # 2. Modern Stripe SDK: every ``StripeObject`` has ``to_dict()``. + # ``recursive=False`` is correct because Stripe metadata values + # are always primitive strings. to_dict = getattr(metadata, "to_dict", None) if callable(to_dict): try: @@ -180,6 +130,8 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: getattr(checkout_session, "id", "?"), ) + # 3. Last-resort: read the SDK's private ``_data`` backing dict. + # Stable across stripe-python 6.x -> 15.x. inner = getattr(metadata, "_data", None) if isinstance(inner, dict): return {str(k): str(v) for k, v in inner.items()} @@ -192,90 +144,120 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: return {} -# Canonical purchase_type metadata value is ``credits``. ``premium_tokens`` and -# ``premium_credit`` were emitted by earlier releases so they're still accepted -# on the read side for any in-flight checkout sessions. -_PURCHASE_TYPE_CREDIT_VALUES = frozenset( - {"credits", "premium_tokens", "premium_credit"} -) +# Canonical purchase_type metadata values. ``premium_credit`` was emitted +# by an earlier release of ``create_token_checkout_session`` so it's still +# accepted on the read side for backward compat with in-flight sessions. +_PURCHASE_TYPE_TOKEN_VALUES = frozenset({"premium_tokens", "premium_credit"}) -def _is_credit_purchase(metadata: dict[str, str]) -> bool: - """Return True for a credit purchase (default for all live checkouts).""" - return metadata.get("purchase_type", "credits") in _PURCHASE_TYPE_CREDIT_VALUES +def _is_token_purchase(metadata: dict[str, str]) -> bool: + """Return True for premium-credit (a.k.a. premium_token) purchases.""" + return metadata.get("purchase_type", "page_packs") in _PURCHASE_TYPE_TOKEN_VALUES -async def _mark_credit_purchase_failed( +async def _get_or_create_purchase_from_checkout_session( + db_session: AsyncSession, + checkout_session: Any, +) -> PagePurchase | None: + """Look up a PagePurchase by checkout session ID (with FOR UPDATE lock). + + If the row doesn't exist yet (e.g. the webhook arrived before the API + response committed), create one from the Stripe session metadata. + """ + checkout_session_id = str(checkout_session.id) + purchase = ( + await db_session.execute( + select(PagePurchase) + .where(PagePurchase.stripe_checkout_session_id == checkout_session_id) + .with_for_update() + ) + ).scalar_one_or_none() + if purchase is not None: + return purchase + + metadata = _get_metadata(checkout_session) + user_id = metadata.get("user_id") + quantity = int(metadata.get("quantity", "0")) + pages_per_unit = int(metadata.get("pages_per_unit", "0")) + + if not user_id or quantity <= 0 or pages_per_unit <= 0: + logger.error( + "Skipping Stripe fulfillment for session %s due to incomplete metadata: %s", + checkout_session_id, + metadata, + ) + return None + + purchase = PagePurchase( + user_id=uuid.UUID(user_id), + stripe_checkout_session_id=checkout_session_id, + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=quantity, + pages_granted=quantity * pages_per_unit, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PagePurchaseStatus.PENDING, + ) + db_session.add(purchase) + await db_session.flush() + return purchase + + +async def _mark_purchase_failed( db_session: AsyncSession, checkout_session_id: str ) -> StripeWebhookResponse: purchase = ( await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_checkout_session_id == checkout_session_id) + select(PagePurchase) + .where(PagePurchase.stripe_checkout_session_id == checkout_session_id) .with_for_update() ) ).scalar_one_or_none() - if purchase is not None and purchase.status == CreditPurchaseStatus.PENDING: - purchase.status = CreditPurchaseStatus.FAILED + if purchase is not None and purchase.status == PagePurchaseStatus.PENDING: + purchase.status = PagePurchaseStatus.FAILED await db_session.commit() return StripeWebhookResponse() -async def _fulfill_completed_credit_purchase( - db_session: AsyncSession, checkout_session: Any +async def _mark_token_purchase_failed( + db_session: AsyncSession, checkout_session_id: str ) -> StripeWebhookResponse: - """Grant credit to the user after a confirmed Stripe payment. - - Uses ``SELECT ... FOR UPDATE`` on both the CreditPurchase and User rows to - prevent double-granting when Stripe retries the webhook concurrently. - """ - checkout_session_id = str(checkout_session.id) purchase = ( await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_checkout_session_id == checkout_session_id) + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id + ) .with_for_update() ) ).scalar_one_or_none() + if purchase is not None and purchase.status == PremiumTokenPurchaseStatus.PENDING: + purchase.status = PremiumTokenPurchaseStatus.FAILED + await db_session.commit() + + return StripeWebhookResponse() + + +async def _fulfill_completed_purchase( + db_session: AsyncSession, checkout_session: Any +) -> StripeWebhookResponse: + """Grant pages to the user after a confirmed Stripe payment. + + Uses SELECT ... FOR UPDATE on both the PagePurchase and User rows to + prevent double-granting when Stripe retries the webhook concurrently. + """ + purchase = await _get_or_create_purchase_from_checkout_session( + db_session, checkout_session + ) if purchase is None: - metadata = _get_metadata(checkout_session) - user_id = metadata.get("user_id") - quantity = int(metadata.get("quantity", "0")) - # Read the new metadata key first, fall back to legacy ones so - # in-flight checkout sessions created before the rename still fulfil. - credit_micros_per_unit = int( - metadata.get("credit_micros_per_unit") - or metadata.get("tokens_per_unit", "0") - ) + return StripeWebhookResponse() - if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: - logger.error( - "Skipping credit fulfillment for session %s: incomplete metadata %s", - checkout_session_id, - metadata, - ) - return StripeWebhookResponse() - - purchase = CreditPurchase( - user_id=uuid.UUID(user_id), - stripe_checkout_session_id=checkout_session_id, - stripe_payment_intent_id=_normalize_optional_string( - getattr(checkout_session, "payment_intent", None) - ), - quantity=quantity, - credit_micros_granted=quantity * credit_micros_per_unit, - amount_total=getattr(checkout_session, "amount_total", None), - currency=getattr(checkout_session, "currency", None), - source="checkout", - status=CreditPurchaseStatus.PENDING, - ) - db_session.add(purchase) - await db_session.flush() - - if purchase.status == CreditPurchaseStatus.COMPLETED: + if purchase.status == PagePurchaseStatus.COMPLETED: return StripeWebhookResponse() user = ( @@ -289,188 +271,132 @@ async def _fulfill_completed_credit_purchase( ) if user is None: logger.error( - "Skipping credit fulfillment for session %s: user %s not found", + "Skipping Stripe fulfillment for session %s because user %s was not found.", purchase.stripe_checkout_session_id, purchase.user_id, ) return StripeWebhookResponse() - purchase.status = CreditPurchaseStatus.COMPLETED + purchase.status = PagePurchaseStatus.COMPLETED purchase.completed_at = datetime.now(UTC) purchase.amount_total = getattr(checkout_session, "amount_total", None) purchase.currency = getattr(checkout_session, "currency", None) purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - # Add the granted micro-USD directly to the spendable wallet balance. - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) + # pages_used can exceed pages_limit when a document's final page count is + # determined after processing. Base the new limit on the higher of the two + # so the purchased pages are fully usable above the current high-water mark. + user.pages_limit = max(user.pages_used, user.pages_limit) + purchase.pages_granted await db_session.commit() return StripeWebhookResponse() -async def _handle_setup_session_completed( - stripe_client: StripeClient, - db_session: AsyncSession, - checkout_session: Any, +async def _fulfill_completed_token_purchase( + db_session: AsyncSession, checkout_session: Any ) -> StripeWebhookResponse: - """Persist the saved card from a completed ``mode=setup`` checkout session. - - The setup session saves a card on the customer (Stripe save-and-reuse). We - pull the resulting payment method off the SetupIntent and store it as the - user's ``auto_reload_payment_method_id`` so the off-session charge can use - it. Auto-reload itself is only armed once the user enables it via the - settings endpoint. - """ - metadata = _get_metadata(checkout_session) - user_id = metadata.get("user_id") - if not user_id: - logger.warning( - "Setup session %s completed without user_id metadata", - getattr(checkout_session, "id", "?"), - ) - return StripeWebhookResponse() - - setup_intent_id = _normalize_optional_string( - getattr(checkout_session, "setup_intent", None) - ) - payment_method_id: str | None = None - if setup_intent_id: - try: - setup_intent = stripe_client.v1.setup_intents.retrieve(setup_intent_id) - payment_method_id = _normalize_optional_string( - getattr(setup_intent, "payment_method", None) + """Grant premium tokens to the user after a confirmed Stripe payment.""" + checkout_session_id = str(checkout_session.id) + purchase = ( + await db_session.execute( + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id ) - except StripeError: - logger.exception( - "Failed to retrieve setup intent %s for session %s", - setup_intent_id, - getattr(checkout_session, "id", "?"), - ) - - if not payment_method_id: - logger.warning( - "Setup session %s completed without a payment method", - getattr(checkout_session, "id", "?"), + .with_for_update() ) + ).scalar_one_or_none() + + if purchase is None: + metadata = _get_metadata(checkout_session) + user_id = metadata.get("user_id") + quantity = int(metadata.get("quantity", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) + + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: + logger.error( + "Skipping token fulfillment for session %s: incomplete metadata %s", + checkout_session_id, + metadata, + ) + return StripeWebhookResponse() + + purchase = PremiumTokenPurchase( + user_id=uuid.UUID(user_id), + stripe_checkout_session_id=checkout_session_id, + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=quantity, + credit_micros_granted=quantity * credit_micros_per_unit, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PremiumTokenPurchaseStatus.PENDING, + ) + db_session.add(purchase) + await db_session.flush() + + if purchase.status == PremiumTokenPurchaseStatus.COMPLETED: return StripeWebhookResponse() user = ( ( await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) + select(User).where(User.id == purchase.user_id).with_for_update(of=User) ) ) .unique() .scalar_one_or_none() ) if user is None: + logger.error( + "Skipping token fulfillment for session %s: user %s not found", + purchase.stripe_checkout_session_id, + purchase.user_id, + ) return StripeWebhookResponse() - customer_id = _normalize_optional_string( - getattr(checkout_session, "customer", None) + purchase.status = PremiumTokenPurchaseStatus.COMPLETED + purchase.completed_at = datetime.now(UTC) + purchase.amount_total = getattr(checkout_session, "amount_total", None) + purchase.currency = getattr(checkout_session, "currency", None) + purchase.stripe_payment_intent_id = _normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ) + # Top up the user's credit balance by the granted micro-USD amount. + # ``max(used, limit)`` clamps the case where the legacy code wrote a + # used value above the limit (e.g. underbilling rounding) so adding + # ``credit_micros_granted`` always lifts the limit by the full pack + # size rather than disappearing into past overuse. + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) - if customer_id and not user.stripe_customer_id: - user.stripe_customer_id = customer_id - user.auto_reload_payment_method_id = payment_method_id - await db_session.commit() - - # Make this the customer's default for future off-session charges. - if user.stripe_customer_id: - try: - stripe_client.v1.customers.update( - user.stripe_customer_id, - params={ - "invoice_settings": {"default_payment_method": payment_method_id} - }, - ) - except StripeError: - logger.warning( - "Failed to set default payment method for customer %s", - user.stripe_customer_id, - exc_info=True, - ) - - return StripeWebhookResponse() - - -async def _reconcile_auto_reload_payment_intent( - db_session: AsyncSession, - payment_intent: Any, - *, - succeeded: bool, -) -> StripeWebhookResponse: - """Backstop for the off-session auto-reload charge via webhook. - - The Celery task confirms the PaymentIntent synchronously and grants credit - inline, but the ``payment_intent.succeeded`` / ``payment_intent.payment_failed`` - webhook acts as a safety net. We locate the matching ``auto_reload`` - CreditPurchase by payment-intent id and only transition PENDING rows so we - never double-grant. - """ - payment_intent_id = str(payment_intent.id) - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_payment_intent_id == payment_intent_id) - .with_for_update() - ) - ).scalar_one_or_none() - - if purchase is None or purchase.status != CreditPurchaseStatus.PENDING: - return StripeWebhookResponse() - - if succeeded: - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == purchase.user_id) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is None: - return StripeWebhookResponse() - purchase.status = CreditPurchaseStatus.COMPLETED - purchase.completed_at = datetime.now(UTC) - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) - else: - purchase.status = CreditPurchaseStatus.FAILED await db_session.commit() return StripeWebhookResponse() -@router.post( - "/create-credit-checkout-session", - response_model=CreateCreditCheckoutSessionResponse, -) -async def create_credit_checkout_session( - body: CreateCreditCheckoutSessionRequest, +@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse) +async def create_checkout_session( + body: CreateCheckoutSessionRequest, user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), -) -> CreateCreditCheckoutSessionResponse: - """Create a Stripe Checkout Session for buying credit packs. - - Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of credit - (default 1_000_000 = $1.00). The balance is debited at the actual provider - cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page - (ETL), so $1 of credit always buys $1 worth of usage at cost. - """ - _ensure_credit_buying_enabled() +) -> CreateCheckoutSessionResponse: + """Create a Stripe Checkout Session for buying page packs.""" + _ensure_page_buying_enabled() stripe_client = get_stripe_client() - price_id = _get_required_credit_price_id() + price_id = _get_required_stripe_price_id() success_url, cancel_url = _get_checkout_urls(body.search_space_id) - credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT + pages_granted = body.quantity * config.STRIPE_PAGES_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -489,14 +415,14 @@ async def create_credit_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), - "purchase_type": "credits", + "pages_per_unit": str(config.STRIPE_PAGES_PER_UNIT), + "purchase_type": "page_packs", }, } ) except StripeError as exc: logger.exception( - "Failed to create credit checkout session for user %s", user.id + "Failed to create Stripe checkout session for user %s", user.id ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, @@ -511,23 +437,28 @@ async def create_credit_checkout_session( ) db_session.add( - CreditPurchase( + PagePurchase( user_id=user.id, stripe_checkout_session_id=str(checkout_session.id), stripe_payment_intent_id=_normalize_optional_string( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - credit_micros_granted=credit_micros_granted, + pages_granted=pages_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), - source="checkout", - status=CreditPurchaseStatus.PENDING, + status=PagePurchaseStatus.PENDING, ) ) await db_session.commit() - return CreateCreditCheckoutSessionResponse(checkout_url=checkout_url) + return CreateCheckoutSessionResponse(checkout_url=checkout_url) + + +@router.get("/status", response_model=StripeStatusResponse) +async def get_stripe_status() -> StripeStatusResponse: + """Return page-buying availability for frontend feature gating.""" + return StripeStatusResponse(page_buying_enabled=config.STRIPE_PAGE_BUYING_ENABLED) @router.post("/webhook", response_model=StripeWebhookResponse) @@ -535,7 +466,7 @@ async def stripe_webhook( request: Request, db_session: AsyncSession = Depends(get_async_session), ) -> StripeWebhookResponse: - """Handle Stripe webhooks and grant purchased credit after payment.""" + """Handle Stripe webhooks and grant purchased pages after payment.""" if not config.STRIPE_WEBHOOK_SECRET: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, @@ -587,37 +518,12 @@ async def stripe_webhook( ) return StripeWebhookResponse() - # mode=setup sessions carry no line items / payment; they save a - # card for off-session auto-reload. - if getattr(checkout_session, "mode", None) == "setup": - return await _handle_setup_session_completed( - stripe_client, db_session, checkout_session - ) - metadata = _get_metadata(checkout_session) - if _is_credit_purchase(metadata): - return await _fulfill_completed_credit_purchase( + if _is_token_purchase(metadata): + return await _fulfill_completed_token_purchase( db_session, checkout_session ) - # Legacy page-pack purchase: page buying is removed, so log and - # ignore rather than fulfilling. - logger.info( - "Ignoring non-credit checkout session %s (purchase_type=%s); " - "page buying is removed.", - getattr(checkout_session, "id", "?"), - metadata.get("purchase_type"), - ) - return StripeWebhookResponse() - - if event.type == "payment_intent.succeeded": - return await _reconcile_auto_reload_payment_intent( - db_session, event.data.object, succeeded=True - ) - - if event.type == "payment_intent.payment_failed": - return await _reconcile_auto_reload_payment_intent( - db_session, event.data.object, succeeded=False - ) + return await _fulfill_completed_purchase(db_session, checkout_session) if event.type in { "checkout.session.async_payment_failed", @@ -625,12 +531,16 @@ async def stripe_webhook( }: checkout_session = event.data.object metadata = _get_metadata(checkout_session) - if _is_credit_purchase(metadata): - return await _mark_credit_purchase_failed( + if _is_token_purchase(metadata): + return await _mark_token_purchase_failed( db_session, str(checkout_session.id) ) - return StripeWebhookResponse() + return await _mark_purchase_failed(db_session, str(checkout_session.id)) except Exception: + # Re-raise so FastAPI returns 500 and Stripe retries this delivery. + # Logging here gives us a structured trail with event id + type so + # future webhook bugs surface immediately in the logs without + # having to grep by request_id. logger.exception( "Stripe webhook handler failed for event id=%s type=%s — Stripe will retry", getattr(event, "id", "?"), @@ -647,17 +557,24 @@ async def finalize_checkout( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ) -> FinalizeCheckoutResponse: - """Synchronously fulfil a credit checkout session from the success page. + """Synchronously fulfil a checkout session from the success page. Solves the webhook-vs-redirect race: the user lands on ``/dashboard/<id>/purchase-success?session_id=cs_...`` typically a - few hundred ms after paying, but Stripe's ``checkout.session.completed`` - webhook can take 5-30s+ to arrive. Calling this endpoint on success-page - mount fulfils the purchase immediately via the same idempotent helper the - webhook uses. + few hundred ms after paying, but Stripe's + ``checkout.session.completed`` webhook can take 5-30s+ to arrive. + Calling this endpoint on success-page mount fulfils the purchase + immediately by retrieving the session from Stripe's API and + invoking the same idempotent helpers the webhook uses. + + Idempotency: if the webhook has already fulfilled this purchase + (status=COMPLETED), the helpers short-circuit and we just return + the latest balance. Concurrent webhook + finalize calls are safe + because both acquire ``SELECT ... FOR UPDATE`` on the purchase row. Authorization: the session's ``client_reference_id`` must match the - authenticated user's id. + authenticated user's id. This prevents a user from finalising + someone else's checkout session if they happen to know the id. """ stripe_client = get_stripe_client() @@ -675,6 +592,9 @@ async def finalize_checkout( detail="Checkout session not found.", ) from exc + # Authorization check: the user finalising must be the user who + # initiated the checkout. ``client_reference_id`` is set in + # ``create_checkout_session`` / ``create_token_checkout_session``. client_reference_id = getattr(checkout_session, "client_reference_id", None) if client_reference_id != str(user.id): logger.warning( @@ -688,75 +608,109 @@ async def finalize_checkout( detail="This checkout session does not belong to you.", ) + metadata = _get_metadata(checkout_session) + is_token = _is_token_purchase(metadata) payment_status = getattr(checkout_session, "payment_status", None) session_status = getattr(checkout_session, "status", None) + + # Defensive fallback: if metadata can't be read for any reason + # (extraction failure, manually-created session in Stripe dashboard, + # SDK upgrade breaking ``to_dict``, etc.) we'd otherwise route every + # purchase to the page_packs handler and get stuck. Resolve the + # purchase_type by checking which table actually has the row keyed + # by this Stripe session id. + if not metadata: + existing_token_purchase = ( + await db_session.execute( + select(PremiumTokenPurchase.id).where( + PremiumTokenPurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + if existing_token_purchase is not None: + is_token = True + else: + existing_page_purchase = ( + await db_session.execute( + select(PagePurchase.id).where( + PagePurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + if existing_page_purchase is None: + logger.error( + "finalize_checkout: no purchase row in either table " + "and metadata is empty for session=%s user=%s", + session_id, + user.id, + ) + # Fall through; downstream path will short-circuit on + # missing-row + empty-metadata. + logger.info( + "finalize_checkout: recovered purchase_type=%s for session=%s " + "via DB fallback (metadata was empty)", + "premium_tokens" if is_token else "page_packs", + session_id, + ) + is_paid = payment_status in {"paid", "no_payment_required"} is_expired = session_status == "expired" if is_paid: - await _fulfill_completed_credit_purchase(db_session, checkout_session) + if is_token: + await _fulfill_completed_token_purchase(db_session, checkout_session) + else: + await _fulfill_completed_purchase(db_session, checkout_session) elif is_expired: - await _mark_credit_purchase_failed(db_session, str(checkout_session.id)) - # Otherwise leave the row alone — frontend keeps polling and the webhook - # will eventually win the race. + if is_token: + await _mark_token_purchase_failed(db_session, str(checkout_session.id)) + else: + await _mark_purchase_failed(db_session, str(checkout_session.id)) + # Otherwise (e.g. payment_status="unpaid", session_status="open"), + # leave the purchase row alone — frontend will keep polling and the + # webhook will eventually win the race. + # Refresh the user row so the response reflects any update applied + # by the fulfilment helpers in this same session. await db_session.refresh(user) + if is_token: + purchase = ( + await db_session.execute( + select(PremiumTokenPurchase).where( + PremiumTokenPurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + return FinalizeCheckoutResponse( + purchase_type="premium_tokens", + status=purchase.status.value if purchase else "pending", + premium_credit_micros_limit=user.premium_credit_micros_limit, + premium_credit_micros_used=user.premium_credit_micros_used, + premium_credit_micros_granted=( + purchase.credit_micros_granted if purchase else None + ), + ) + purchase = ( await db_session.execute( - select(CreditPurchase).where( - CreditPurchase.stripe_checkout_session_id == str(checkout_session.id) + select(PagePurchase).where( + PagePurchase.stripe_checkout_session_id == str(checkout_session.id) ) ) ).scalar_one_or_none() return FinalizeCheckoutResponse( + purchase_type="page_packs", status=purchase.status.value if purchase else "pending", - credit_micros_balance=user.credit_micros_balance, - credit_micros_granted=(purchase.credit_micros_granted if purchase else None), + pages_limit=user.pages_limit, + pages_used=user.pages_used, + pages_granted=purchase.pages_granted if purchase else None, ) -@router.get("/credit-status", response_model=CreditStripeStatusResponse) -async def get_credit_status( - user: User = Depends(current_active_user), -) -> CreditStripeStatusResponse: - """Return credit-buying availability and current balance for the frontend. - - ``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE - divides by 1M when displaying. - """ - return CreditStripeStatusResponse( - credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED, - credit_micros_balance=user.credit_micros_balance, - ) - - -@router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse) -async def get_credit_purchases( - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), - offset: int = 0, - limit: int = 50, -) -> CreditPurchaseHistoryResponse: - """Return the authenticated user's credit purchase history.""" - limit = min(limit, 100) - purchases = ( - ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.user_id == user.id) - .order_by(CreditPurchase.created_at.desc()) - .offset(offset) - .limit(limit) - ) - ) - .scalars() - .all() - ) - - return CreditPurchaseHistoryResponse(purchases=purchases) - - @router.get("/purchases", response_model=PagePurchaseHistoryResponse) async def get_page_purchases( user: User = Depends(current_active_user), @@ -764,10 +718,7 @@ async def get_page_purchases( offset: int = 0, limit: int = 50, ) -> PagePurchaseHistoryResponse: - """Return the authenticated user's legacy page-purchase history (read-only). - - Page buying is removed; this endpoint stays for historical records. - """ + """Return the authenticated user's page-purchase history.""" limit = min(limit, 100) purchases = ( ( @@ -786,155 +737,163 @@ async def get_page_purchases( return PagePurchaseHistoryResponse(purchases=purchases) -def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse: - return AutoReloadSettingsResponse( - feature_enabled=config.AUTO_RELOAD_ENABLED, - enabled=bool(user.auto_reload_enabled), - threshold_micros=user.auto_reload_threshold_micros, - amount_micros=user.auto_reload_amount_micros, - min_amount_micros=config.AUTO_RELOAD_MIN_AMOUNT_MICROS, - has_payment_method=bool(user.auto_reload_payment_method_id), - failed_at=user.auto_reload_failed_at, - ) +# ============================================================================= +# Premium Token Purchase Routes +# ============================================================================= -@router.post( - "/auto-reload/setup", - response_model=CreateAutoReloadSetupSessionResponse, -) -async def create_auto_reload_setup_session( - body: CreateAutoReloadSetupSessionRequest, - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), -) -> CreateAutoReloadSetupSessionResponse: - """Start a ``mode=setup`` checkout session to save a card for auto-reload. +def _ensure_token_buying_enabled() -> None: + if not config.STRIPE_TOKEN_BUYING_ENABLED: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Premium token purchases are temporarily unavailable.", + ) - Uses a SetupIntent (no immediate charge) attached to the user's Stripe - Customer so the card can later be charged off-session. On completion the - webhook stores the resulting payment method on the user. - """ - _ensure_auto_reload_enabled() - _ensure_credit_buying_enabled() - stripe_client = get_stripe_client() + +def _get_token_checkout_urls(search_space_id: int) -> tuple[str, str]: if not config.NEXT_FRONTEND_URL: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="NEXT_FRONTEND_URL is not configured.", ) - customer_id = await _get_or_create_stripe_customer(stripe_client, db_session, user) - base_url = config.NEXT_FRONTEND_URL.rstrip("/") + # See ``_get_checkout_urls`` for why session_id is appended. success_url = ( - f"{base_url}/dashboard/{body.search_space_id}/user-settings/purchases" - f"?auto_reload_setup=success" - ) - cancel_url = ( - f"{base_url}/dashboard/{body.search_space_id}/user-settings/purchases" - f"?auto_reload_setup=cancel" + f"{base_url}/dashboard/{search_space_id}/purchase-success" + f"?session_id={{CHECKOUT_SESSION_ID}}" ) + cancel_url = f"{base_url}/dashboard/{search_space_id}/purchase-cancel" + return success_url, cancel_url + + +def _get_required_token_price_id() -> str: + if not config.STRIPE_PREMIUM_TOKEN_PRICE_ID: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="STRIPE_PREMIUM_TOKEN_PRICE_ID is not configured.", + ) + return config.STRIPE_PREMIUM_TOKEN_PRICE_ID + + +@router.post("/create-token-checkout-session") +async def create_token_checkout_session( + body: CreateTokenCheckoutSessionRequest, + user: User = Depends(current_active_user), + db_session: AsyncSession = Depends(get_async_session), +): + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. + """ + _ensure_token_buying_enabled() + stripe_client = get_stripe_client() + price_id = _get_required_token_price_id() + success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( params={ - "mode": "setup", - # Required in setup mode when payment_method_types is omitted - # (dynamic payment methods); auto-reload charges are in USD. - "currency": "usd", + "mode": "payment", "success_url": success_url, "cancel_url": cancel_url, - "customer": customer_id, + "line_items": [ + { + "price": price_id, + "quantity": body.quantity, + } + ], "client_reference_id": str(user.id), + "customer_email": user.email, "metadata": { "user_id": str(user.id), - "purchase_type": "auto_reload_setup", + "quantity": str(body.quantity), + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + # Canonical value matched by ``_is_token_purchase``. + # The legacy ``"premium_credit"`` is still accepted on + # the read side for any in-flight sessions started + # before this rename. + "purchase_type": "premium_tokens", }, } ) except StripeError as exc: - logger.exception( - "Failed to create auto-reload setup session for user %s", user.id - ) + logger.exception("Failed to create token checkout session for user %s", user.id) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="Unable to create Stripe setup session.", + detail="Unable to create Stripe checkout session.", ) from exc checkout_url = getattr(checkout_session, "url", None) if not checkout_url: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="Stripe setup session did not return a URL.", + detail="Stripe checkout session did not return a URL.", ) - return CreateAutoReloadSetupSessionResponse(checkout_url=checkout_url) + db_session.add( + PremiumTokenPurchase( + user_id=user.id, + stripe_checkout_session_id=str(checkout_session.id), + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=body.quantity, + credit_micros_granted=credit_micros_granted, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PremiumTokenPurchaseStatus.PENDING, + ) + ) + await db_session.commit() + + return CreateTokenCheckoutSessionResponse(checkout_url=checkout_url) -@router.get("/auto-reload", response_model=AutoReloadSettingsResponse) -async def get_auto_reload_settings( +@router.get("/token-status") +async def get_token_status( user: User = Depends(current_active_user), -) -> AutoReloadSettingsResponse: - """Return the user's auto-reload configuration and saved-card state.""" - return _auto_reload_settings_response(user) +): + """Return token-buying availability and current premium credit quota for frontend. - -@router.put("/auto-reload", response_model=AutoReloadSettingsResponse) -async def update_auto_reload_settings( - body: UpdateAutoReloadSettingsRequest, - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), -) -> AutoReloadSettingsResponse: - """Update auto-reload preferences. - - Enabling requires a saved card plus a positive threshold and an amount of - at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and - clears any prior failure flag. + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments. """ - _ensure_auto_reload_enabled() - - locked = ( - ( - await db_session.execute( - select(User).where(User.id == user.id).with_for_update(of=User) - ) - ) - .unique() - .scalar_one() + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit + return TokenStripeStatusResponse( + token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) - if body.enabled: - if not locked.auto_reload_payment_method_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Add a payment method before enabling auto-reload.", - ) - if not body.threshold_micros or body.threshold_micros <= 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="A positive low-balance threshold is required.", - ) - if ( - body.amount_micros is None - or body.amount_micros < config.AUTO_RELOAD_MIN_AMOUNT_MICROS - ): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "Reload amount must be at least " - f"{config.AUTO_RELOAD_MIN_AMOUNT_MICROS} micro-USD." - ), - ) - locked.auto_reload_enabled = True - locked.auto_reload_threshold_micros = body.threshold_micros - locked.auto_reload_amount_micros = body.amount_micros - # Re-enabling clears the prior failure flag so the user can retry. - locked.auto_reload_failed_at = None - else: - locked.auto_reload_enabled = False - if body.threshold_micros is not None: - locked.auto_reload_threshold_micros = body.threshold_micros - if body.amount_micros is not None: - locked.auto_reload_amount_micros = body.amount_micros - await db_session.commit() - await db_session.refresh(locked) - return _auto_reload_settings_response(locked) +@router.get("/token-purchases") +async def get_token_purchases( + user: User = Depends(current_active_user), + db_session: AsyncSession = Depends(get_async_session), + offset: int = 0, + limit: int = 50, +): + """Return the authenticated user's premium token purchase history.""" + limit = min(limit, 100) + purchases = ( + ( + await db_session.execute( + select(PremiumTokenPurchase) + .where(PremiumTokenPurchase.user_id == user.id) + .order_by(PremiumTokenPurchase.created_at.desc()) + .offset(offset) + .limit(limit) + ) + ) + .scalars() + .all() + ) + + return TokenPurchaseHistoryResponse(purchases=purchases) diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py new file mode 100644 index 000000000..e4f08f604 --- /dev/null +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -0,0 +1,304 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ( + Permission, + User, + VisionLLMConfig, + get_async_session, +) +from app.schemas import ( + GlobalVisionLLMConfigRead, + VisionLLMConfigCreate, + VisionLLMConfigRead, + VisionLLMConfigUpdate, +) +from app.services.vision_model_list_service import get_vision_model_list +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Vision Model Catalogue (from OpenRouter, filtered for image-input models) +# ============================================================================= + + +class VisionModelListItem(BaseModel): + value: str + label: str + provider: str + context_window: str | None = None + + +@router.get("/vision-models", response_model=list[VisionModelListItem]) +async def list_vision_models( + user: User = Depends(current_active_user), +): + """Return vision-capable models sourced from OpenRouter (filtered by image input).""" + try: + return await get_vision_model_list() + except Exception as e: + logger.exception("Failed to fetch vision model list") + raise HTTPException( + status_code=500, detail=f"Failed to fetch vision model list: {e!s}" + ) from e + + +# ============================================================================= +# Global Vision LLM Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-vision-llm-configs", + response_model=list[GlobalVisionLLMConfigRead], +) +async def get_global_vision_llm_configs( + user: User = Depends(current_active_user), +): + try: + global_configs = config.GLOBAL_VISION_LLM_CONFIGS + safe_configs = [] + + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes across available vision LLM providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", + "is_premium": False, + } + ) + + for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), + } + ) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global vision LLM configs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +# ============================================================================= +# VisionLLMConfig CRUD +# ============================================================================= + + +@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead) +async def create_vision_llm_config( + config_data: VisionLLMConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + await check_permission( + session, + user, + config_data.search_space_id, + Permission.VISION_CONFIGS_CREATE.value, + "You don't have permission to create vision LLM configs in this search space", + ) + + db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e + + +@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead]) +async def list_vision_llm_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + await check_permission( + session, + user, + search_space_id, + Permission.VISION_CONFIGS_READ.value, + "You don't have permission to view vision LLM configs in this search space", + ) + + result = await session.execute( + select(VisionLLMConfig) + .filter(VisionLLMConfig.search_space_id == search_space_id) + .order_by(VisionLLMConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list VisionLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) +async def get_vision_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_READ.value, + "You don't have permission to view vision LLM configs in this search space", + ) + return db_config + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e + + +@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) +async def update_vision_llm_config( + config_id: int, + update_data: VisionLLMConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_CREATE.value, + "You don't have permission to update vision LLM configs in this search space", + ) + + for key, value in update_data.model_dump(exclude_unset=True).items(): + setattr(db_config, key, value) + + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e + + +@router.delete("/vision-llm-configs/{config_id}", response_model=dict) +async def delete_vision_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_DELETE.value, + "You don't have permission to delete vision LLM configs in this search space", + ) + + await session.delete(db_config) + await session.commit() + return { + "message": "Vision LLM config deleted successfully", + "id": config_id, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 7b508a132..fdf34672b 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -34,27 +34,16 @@ from .folders import ( ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest from .image_generation import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigPublic, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate -from .model_connections import ( - ConnectionCreate, - ConnectionRead, - ConnectionUpdate, - ModelCreate, - ModelPreviewRead, - ModelProviderRead, - ModelRead, - ModelRolesRead, - ModelRolesUpdate, - ModelsBulkUpdate, - ModelSelection, - ModelTestPreview, - ModelUpdate, - VerifyConnectionResponse, -) from .new_chat import ( ChatMessage, NewChatMessageAppend, @@ -69,6 +58,17 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) +from .new_llm_config import ( + DefaultSystemInstructionsResponse, + GlobalNewLLMConfigRead, + LLMPreferencesRead, + LLMPreferencesUpdate, + NewLLMConfigCreate, + NewLLMConfigPublic, + NewLLMConfigRead, + NewLLMConfigUpdate, +) +from .podcasts import PodcastBase, PodcastCreate, PodcastRead, PodcastUpdate from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -111,13 +111,11 @@ from .search_space import ( SearchSpaceWithStats, ) from .stripe import ( - CreateCreditCheckoutSessionRequest, - CreateCreditCheckoutSessionResponse, - CreditPurchaseHistoryResponse, - CreditPurchaseRead, - CreditStripeStatusResponse, + CreateCheckoutSessionRequest, + CreateCheckoutSessionResponse, PagePurchaseHistoryResponse, PagePurchaseRead, + StripeStatusResponse, StripeWebhookResponse, ) from .users import UserCreate, UserRead, UserUpdate @@ -127,6 +125,13 @@ from .video_presentations import ( VideoPresentationRead, VideoPresentationUpdate, ) +from .vision_llm import ( + GlobalVisionLLMConfigRead, + VisionLLMConfigCreate, + VisionLLMConfigPublic, + VisionLLMConfigRead, + VisionLLMConfigUpdate, +) __all__ = [ # Folder schemas @@ -138,15 +143,9 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", - # Model connection schemas - "ConnectionCreate", - "ConnectionRead", - "ConnectionUpdate", - "CreateCreditCheckoutSessionRequest", - "CreateCreditCheckoutSessionResponse", - "CreditPurchaseHistoryResponse", - "CreditPurchaseRead", - "CreditStripeStatusResponse", + "CreateCheckoutSessionRequest", + "CreateCheckoutSessionResponse", + "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", "DocumentMove", @@ -169,10 +168,19 @@ __all__ = [ "FolderRead", "FolderReorder", "FolderUpdate", + "GlobalImageGenConfigRead", + "GlobalNewLLMConfigRead", + # Vision LLM Config schemas + "GlobalVisionLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", + # Image Generation Config schemas + "ImageGenerationConfigCreate", + "ImageGenerationConfigPublic", + "ImageGenerationConfigRead", + "ImageGenerationConfigUpdate", # Image Generation schemas "ImageGenerationCreate", "ImageGenerationListRead", @@ -184,6 +192,9 @@ __all__ = [ "InviteInfoResponse", "InviteRead", "InviteUpdate", + # LLM Preferences schemas + "LLMPreferencesRead", + "LLMPreferencesUpdate", # Log schemas "LogBase", "LogCreate", @@ -202,16 +213,6 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", - "ModelCreate", - "ModelPreviewRead", - "ModelProviderRead", - "ModelRead", - "ModelRolesRead", - "ModelRolesUpdate", - "ModelSelection", - "ModelTestPreview", - "ModelUpdate", - "ModelsBulkUpdate", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -220,12 +221,21 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", + # NewLLMConfig schemas + "NewLLMConfigCreate", + "NewLLMConfigPublic", + "NewLLMConfigRead", + "NewLLMConfigUpdate", "PagePurchaseHistoryResponse", "PagePurchaseRead", "PaginatedResponse", "PermissionInfo", "PermissionsListResponse", # Podcast schemas + "PodcastBase", + "PodcastCreate", + "PodcastRead", + "PodcastUpdate", "RefreshTokenRequest", "RefreshTokenResponse", # Report schemas @@ -247,6 +257,7 @@ __all__ = [ "SearchSpaceRead", "SearchSpaceUpdate", "SearchSpaceWithStats", + "StripeStatusResponse", "StripeWebhookResponse", "ThreadHistoryLoadResponse", "ThreadListItem", @@ -257,10 +268,13 @@ __all__ = [ "UserRead", "UserSearchSpaceAccess", "UserUpdate", - "VerifyConnectionResponse", # Video Presentation schemas "VideoPresentationBase", "VideoPresentationCreate", "VideoPresentationRead", "VideoPresentationUpdate", + "VisionLLMConfigCreate", + "VisionLLMConfigPublic", + "VisionLLMConfigRead", + "VisionLLMConfigUpdate", ] diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index ebd0fa0ac..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -1,10 +1,109 @@ -"""Pydantic schemas for image generation requests/results.""" +""" +Pydantic schemas for Image Generation configs and generation requests. +ImageGenerationConfig: CRUD schemas for user-created image gen model configs. +ImageGeneration: Schemas for the actual image generation requests/results. +GlobalImageGenConfigRead: Schema for admin-configured YAML configs. +""" + +import uuid from datetime import datetime from typing import Any from pydantic import BaseModel, ConfigDict, Field +from app.db import ImageGenProvider + +# ============================================================================= +# ImageGenerationConfig CRUD Schemas +# ============================================================================= + + +class ImageGenerationConfigBase(BaseModel): + """Base schema with fields for ImageGenerationConfig.""" + + name: str = Field( + ..., max_length=100, description="User-friendly name for the config" + ) + description: str | None = Field( + None, max_length=500, description="Optional description" + ) + provider: ImageGenProvider = Field( + ..., + description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", + ) + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name" + ) + model_name: str = Field( + ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" + ) + api_key: str = Field(..., description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + api_version: str | None = Field( + None, + max_length=50, + description="Azure-specific API version (e.g., '2024-02-15-preview')", + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + + +class ImageGenerationConfigCreate(ImageGenerationConfigBase): + """Schema for creating a new ImageGenerationConfig.""" + + search_space_id: int = Field( + ..., description="Search space ID to associate the config with" + ) + + +class ImageGenerationConfigUpdate(BaseModel): + """Schema for updating an existing ImageGenerationConfig. All fields optional.""" + + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + provider: ImageGenProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = None + + +class ImageGenerationConfigRead(ImageGenerationConfigBase): + """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" + + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class ImageGenerationConfigPublic(BaseModel): + """Public schema that hides the API key (for list views).""" + + id: int + name: str + description: str | None = None + provider: ImageGenProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + # ============================================================================= # ImageGeneration (request/result) Schemas # ============================================================================= @@ -37,12 +136,12 @@ class ImageGenerationCreate(BaseModel): search_space_id: int = Field( ..., description="Search space ID to associate the generation with" ) - image_gen_model_id: int | None = Field( + image_generation_config_id: int | None = Field( None, description=( - "Image generation model ID. " - "0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. " - "If not provided, uses the search space's image_gen_model_id preference." + "Image generation config ID. " + "0 = Auto mode (router), negative = global YAML config, positive = DB config. " + "If not provided, uses the search space's image_generation_config_id preference." ), ) @@ -58,7 +157,7 @@ class ImageGenerationRead(BaseModel): size: str | None = None style: str | None = None response_format: str | None = None - image_gen_model_id: int | None = None + image_generation_config_id: int | None = None response_data: dict[str, Any] | None = None error_message: str | None = None search_space_id: int @@ -104,3 +203,58 @@ class ImageGenerationListRead(BaseModel): is_success=obj.response_data is not None, image_count=image_count, ) + + +# ============================================================================= +# Global Image Gen Config (from YAML) +# ============================================================================= + + +class GlobalImageGenConfigRead(BaseModel): + """ + Schema for reading global image generation configs from YAML. + Global configs have negative IDs. API key is hidden. + ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + + id: int = Field( + ..., + description="Config ID: 0 for Auto mode, negative for global configs", + ) + name: str + description: str | None = None + provider: str + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + is_global: bool = True + is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." + ), + ) diff --git a/surfsense_backend/app/schemas/incentive_tasks.py b/surfsense_backend/app/schemas/incentive_tasks.py index 7b9b39cd1..52c2a5182 100644 --- a/surfsense_backend/app/schemas/incentive_tasks.py +++ b/surfsense_backend/app/schemas/incentive_tasks.py @@ -15,8 +15,7 @@ class IncentiveTaskInfo(BaseModel): task_type: IncentiveTaskType title: str description: str - # Credit reward in USD micro-units (1_000_000 == $1.00). - credit_micros_reward: int + pages_reward: int action_url: str completed: bool completed_at: datetime | None = None @@ -26,7 +25,7 @@ class IncentiveTasksResponse(BaseModel): """Response containing all available incentive tasks with completion status.""" tasks: list[IncentiveTaskInfo] - total_credit_micros_earned: int + total_pages_earned: int class CompleteTaskRequest(BaseModel): @@ -40,8 +39,8 @@ class CompleteTaskResponse(BaseModel): success: bool message: str - credit_micros_awarded: int - new_balance_micros: int + pages_awarded: int + new_pages_limit: int class TaskAlreadyCompletedResponse(BaseModel): diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py deleted file mode 100644 index 0eec666c1..000000000 --- a/surfsense_backend/app/schemas/model_connections.py +++ /dev/null @@ -1,148 +0,0 @@ -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import ConnectionScope, ModelSource - - -class ModelRead(BaseModel): - id: int - connection_id: int - model_id: str - display_name: str | None = None - source: ModelSource | str - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - capabilities_override: dict[str, Any] = Field(default_factory=dict) - enabled: bool - billing_tier: str | None = None - catalog: dict[str, Any] = Field(default_factory=dict) - created_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - -class ConnectionRead(BaseModel): - id: int - provider: str - base_url: str | None = None - api_key: str | None = None - extra: dict[str, Any] = Field(default_factory=dict) - scope: ConnectionScope | str - search_space_id: int | None = None - user_id: uuid.UUID | None = None - enabled: bool - has_api_key: bool - models: list[ModelRead] = Field(default_factory=list) - created_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - -class ModelSelection(BaseModel): - model_id: str = Field(..., max_length=255) - display_name: str | None = Field(None, max_length=255) - source: ModelSource | str = ModelSource.DISCOVERED - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - enabled: bool = False - metadata: dict[str, Any] = Field(default_factory=dict) - - -class ModelPreviewRead(BaseModel): - model_id: str - display_name: str | None = None - source: ModelSource | str = ModelSource.DISCOVERED - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - enabled: bool = False - metadata: dict[str, Any] = Field(default_factory=dict) - - -class ConnectionCreate(BaseModel): - provider: str = Field(..., max_length=100) - base_url: str | None = Field(None, max_length=500) - api_key: str | None = None - extra: dict[str, Any] = Field(default_factory=dict) - scope: ConnectionScope = ConnectionScope.SEARCH_SPACE - search_space_id: int | None = None - enabled: bool = True - models: list[ModelSelection] = Field(default_factory=list) - - -class ModelTestPreview(ConnectionCreate): - model_id: str = Field(..., max_length=255) - - -class ConnectionUpdate(BaseModel): - provider: str | None = Field(None, max_length=100) - base_url: str | None = Field(None, max_length=500) - api_key: str | None = None - extra: dict[str, Any] | None = None - enabled: bool | None = None - - -class ModelCreate(BaseModel): - """Manually register a model id on a connection. - - For providers without a usable ``/models`` endpoint (Perplexity, MiniMax, - Azure deployments, etc.) or to pin a single model from a noisy provider. - """ - - model_id: str = Field(..., max_length=255) - display_name: str | None = Field(None, max_length=255) - - -class ModelUpdate(BaseModel): - display_name: str | None = Field(None, max_length=255) - enabled: bool | None = None - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - capabilities_override: dict[str, Any] | None = None - - -class ModelsBulkUpdate(BaseModel): - model_ids: list[int] = Field(..., min_length=1, max_length=1000) - enabled: bool - - -class ModelProviderRead(BaseModel): - provider: str - transport: str - discovery: str - default_base_url: str | None = None - base_url_required: bool - auth_style: str - local_only: bool = False - - -class VerifyConnectionResponse(BaseModel): - status: str - ok: bool - message: str = "" - - -class ModelRolesRead(BaseModel): - chat_model_id: int | None = 0 - vision_model_id: int | None = 0 - image_gen_model_id: int | None = 0 - - -class ModelRolesUpdate(BaseModel): - chat_model_id: int | None = None - vision_model_id: int | None = None - image_gen_model_id: int | None = None diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py new file mode 100644 index 000000000..716aa0457 --- /dev/null +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -0,0 +1,256 @@ +""" +Pydantic schemas for the NewLLMConfig API. + +NewLLMConfig combines model settings with prompt configuration: +- LLM provider, model, API key, etc. +- Configurable system instructions +- Citation toggle +""" + +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import LiteLLMProvider + + +class NewLLMConfigBase(BaseModel): + """Base schema with common fields for NewLLMConfig.""" + + name: str = Field( + ..., max_length=100, description="User-friendly name for the configuration" + ) + description: str | None = Field( + None, max_length=500, description="Optional description" + ) + + # Model Configuration + provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name when provider is CUSTOM" + ) + model_name: str = Field( + ..., max_length=100, description="Model name without provider prefix" + ) + api_key: str = Field(..., description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + + # Prompt Configuration + system_instructions: str = Field( + default="", + description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.", + ) + use_default_system_instructions: bool = Field( + default=True, + description="Whether to use default instructions when system_instructions is empty", + ) + citations_enabled: bool = Field( + default=True, + description="Whether to include citation instructions in the system prompt", + ) + + +class NewLLMConfigCreate(NewLLMConfigBase): + """Schema for creating a new NewLLMConfig.""" + + search_space_id: int = Field( + ..., description="Search space ID to associate the config with" + ) + + +class NewLLMConfigUpdate(BaseModel): + """Schema for updating an existing NewLLMConfig. All fields are optional.""" + + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + + # Model Configuration + provider: LiteLLMProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str | None = None + use_default_system_instructions: bool | None = None + citations_enabled: bool | None = None + + +class NewLLMConfigRead(NewLLMConfigBase): + """Schema for reading a NewLLMConfig (includes id and timestamps).""" + + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) + + model_config = ConfigDict(from_attributes=True) + + +class NewLLMConfigPublic(BaseModel): + """ + Public schema for NewLLMConfig that hides the API key. + Used when returning configs in list views or to users who shouldn't see keys. + """ + + id: int + name: str + description: str | None = None + + # Model Configuration (no api_key) + provider: LiteLLMProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str + use_default_system_instructions: bool + citations_enabled: bool + + created_at: datetime + search_space_id: int + user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) + + model_config = ConfigDict(from_attributes=True) + + +class DefaultSystemInstructionsResponse(BaseModel): + """Response schema for getting default system instructions.""" + + default_system_instructions: str = Field( + ..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template" + ) + + +class GlobalNewLLMConfigRead(BaseModel): + """ + Schema for reading global LLM configs from YAML. + Global configs have negative IDs and no search_space_id. + API key is hidden for security. + + ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing. + """ + + id: int = Field( + ..., + description="Config ID: 0 for Auto mode, negative for global configs", + ) + name: str + description: str | None = None + + # Model Configuration (no api_key) + provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode + custom_provider: str | None = None + model_name: str + api_base: str | None = None + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str = "" + use_default_system_instructions: bool = True + citations_enabled: bool = True + + is_global: bool = True # Always true for global configs + is_auto_mode: bool = False # True only for Auto mode (ID 0) + + billing_tier: str = "free" + is_premium: bool = False + anonymous_enabled: bool = False + seo_enabled: bool = False + seo_slug: str | None = None + seo_title: str | None = None + seo_description: str | None = None + quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." + ), + ) + + +# ============================================================================= +# LLM Preferences Schemas (for role assignments) +# ============================================================================= + + +class LLMPreferencesRead(BaseModel): + """Schema for reading LLM preferences (role assignments) for a search space.""" + + agent_llm_id: int | None = Field( + None, description="ID of the LLM config to use for agent/chat tasks" + ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) + vision_llm_config_id: int | None = Field( + None, + description="ID of the vision LLM config to use for vision/screenshot analysis", + ) + agent_llm: dict[str, Any] | None = Field( + None, description="Full config for agent LLM" + ) + image_generation_config: dict[str, Any] | None = Field( + None, description="Full config for image generation" + ) + vision_llm_config: dict[str, Any] | None = Field( + None, description="Full config for vision LLM" + ) + + model_config = ConfigDict(from_attributes=True) + + +class LLMPreferencesUpdate(BaseModel): + """Schema for updating LLM preferences.""" + + agent_llm_id: int | None = Field( + None, description="ID of the LLM config to use for agent/chat tasks" + ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) + vision_llm_config_id: int | None = Field( + None, + description="ID of the vision LLM config to use for vision/screenshot analysis", + ) diff --git a/surfsense_backend/app/schemas/podcasts.py b/surfsense_backend/app/schemas/podcasts.py new file mode 100644 index 000000000..d41f1ca36 --- /dev/null +++ b/surfsense_backend/app/schemas/podcasts.py @@ -0,0 +1,66 @@ +"""Podcast schemas for API responses.""" + +from datetime import datetime +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel + + +class PodcastStatusEnum(StrEnum): + PENDING = "pending" + GENERATING = "generating" + READY = "ready" + FAILED = "failed" + + +class PodcastBase(BaseModel): + """Base podcast schema.""" + + title: str + podcast_transcript: list[dict[str, Any]] | None = None + file_location: str | None = None + search_space_id: int + + +class PodcastCreate(PodcastBase): + """Schema for creating a podcast.""" + + pass + + +class PodcastUpdate(BaseModel): + """Schema for updating a podcast.""" + + title: str | None = None + podcast_transcript: list[dict[str, Any]] | None = None + file_location: str | None = None + + +class PodcastRead(PodcastBase): + """Schema for reading a podcast.""" + + id: int + status: PodcastStatusEnum = PodcastStatusEnum.READY + created_at: datetime + transcript_entries: int | None = None + + class Config: + from_attributes = True + + @classmethod + def from_orm_with_entries(cls, obj): + """Create PodcastRead with transcript_entries computed.""" + data = { + "id": obj.id, + "title": obj.title, + "podcast_transcript": obj.podcast_transcript, + "file_location": obj.file_location, + "search_space_id": obj.search_space_id, + "status": obj.status, + "created_at": obj.created_at, + "transcript_entries": len(obj.podcast_transcript) + if obj.podcast_transcript + else None, + } + return cls(**data) diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 95c946a3d..ad13ddf04 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -1,4 +1,4 @@ -"""Schemas for Stripe-backed credit purchases.""" +"""Schemas for Stripe-backed page purchases.""" import uuid from datetime import datetime @@ -8,59 +8,27 @@ from pydantic import BaseModel, ConfigDict, Field from app.db import PagePurchaseStatus -class CreateCreditCheckoutSessionRequest(BaseModel): - """Request body for creating a credit-purchase checkout session.""" +class CreateCheckoutSessionRequest(BaseModel): + """Request body for creating a page-purchase checkout session.""" - quantity: int = Field(ge=1, le=10_000) + quantity: int = Field(ge=1, le=100) search_space_id: int = Field(ge=1) -class CreateCreditCheckoutSessionResponse(BaseModel): +class CreateCheckoutSessionResponse(BaseModel): """Response containing the Stripe-hosted checkout URL.""" checkout_url: str -class CreditPurchaseRead(BaseModel): - """Serialized credit purchase record. +class StripeStatusResponse(BaseModel): + """Response describing Stripe page-buying availability.""" - ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). - """ - - id: uuid.UUID - stripe_checkout_session_id: str - stripe_payment_intent_id: str | None = None - quantity: int - credit_micros_granted: int - amount_total: int | None = None - currency: str | None = None - source: str = "checkout" - status: str - completed_at: datetime | None = None - created_at: datetime - - model_config = ConfigDict(from_attributes=True) - - -class CreditPurchaseHistoryResponse(BaseModel): - """Response containing the user's credit purchases.""" - - purchases: list[CreditPurchaseRead] - - -class CreditStripeStatusResponse(BaseModel): - """Response describing credit-buying availability and current balance. - - ``credit_micros_balance`` is in micro-USD; the FE divides by 1_000_000 - to display USD. - """ - - credit_buying_enabled: bool - credit_micros_balance: int = 0 + page_buying_enabled: bool class PagePurchaseRead(BaseModel): - """Serialized legacy page-purchase record (read-only history).""" + """Serialized page-purchase record for purchase history.""" id: uuid.UUID stripe_checkout_session_id: str @@ -77,52 +45,11 @@ class PagePurchaseRead(BaseModel): class PagePurchaseHistoryResponse(BaseModel): - """Response containing the authenticated user's legacy page purchases.""" + """Response containing the authenticated user's page purchases.""" purchases: list[PagePurchaseRead] -class AutoReloadSettingsResponse(BaseModel): - """Auto-reload configuration + saved-card state for the settings UI. - - All ``*_micros`` fields are micro-USD (1_000_000 == $1.00). ``feature_enabled`` - reflects the server-side ``AUTO_RELOAD_ENABLED`` flag; when it is false the - UI should hide / disable the auto-reload controls entirely. - """ - - feature_enabled: bool - enabled: bool = False - threshold_micros: int | None = None - amount_micros: int | None = None - min_amount_micros: int - has_payment_method: bool = False - failed_at: datetime | None = None - - -class UpdateAutoReloadSettingsRequest(BaseModel): - """Update auto-reload preferences. - - Enabling requires a saved card (set up via /stripe/auto-reload/setup) plus a - positive threshold and an amount of at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. - """ - - enabled: bool - threshold_micros: int | None = Field(default=None, ge=0) - amount_micros: int | None = Field(default=None, ge=0) - - -class CreateAutoReloadSetupSessionRequest(BaseModel): - """Request body for starting the save-a-card (SetupIntent) checkout.""" - - search_space_id: int = Field(ge=1) - - -class CreateAutoReloadSetupSessionResponse(BaseModel): - """Response containing the Stripe-hosted setup (save-card) checkout URL.""" - - checkout_url: str - - class StripeWebhookResponse(BaseModel): """Generic acknowledgement for Stripe webhook delivery.""" @@ -139,6 +66,64 @@ class FinalizeCheckoutResponse(BaseModel): endpoint until it sees ``completed`` or a final ``failed``. """ + purchase_type: str # "page_packs" | "premium_tokens" + status: str # PagePurchaseStatus / PremiumTokenPurchaseStatus value + pages_limit: int | None = None + pages_used: int | None = None + pages_granted: int | None = None + premium_credit_micros_limit: int | None = None + premium_credit_micros_used: int | None = None + premium_credit_micros_granted: int | None = None + + +class CreateTokenCheckoutSessionRequest(BaseModel): + """Request body for creating a premium token purchase checkout session.""" + + quantity: int = Field(ge=1, le=100) + search_space_id: int = Field(ge=1) + + +class CreateTokenCheckoutSessionResponse(BaseModel): + """Response containing the Stripe-hosted checkout URL.""" + + checkout_url: str + + +class TokenPurchaseRead(BaseModel): + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema name kept ``Token`` for API back-compat with pinned clients. + """ + + id: uuid.UUID + stripe_checkout_session_id: str + stripe_payment_intent_id: str | None = None + quantity: int + credit_micros_granted: int + amount_total: int | None = None + currency: str | None = None status: str - credit_micros_balance: int = 0 - credit_micros_granted: int | None = None + completed_at: datetime | None = None + created_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class TokenPurchaseHistoryResponse(BaseModel): + """Response containing the user's premium credit purchases.""" + + purchases: list[TokenPurchaseRead] + + +class TokenStripeStatusResponse(BaseModel): + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. + """ + + token_buying_enabled: bool + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/users.py b/surfsense_backend/app/schemas/users.py index 558463f57..88d0a4f37 100644 --- a/surfsense_backend/app/schemas/users.py +++ b/surfsense_backend/app/schemas/users.py @@ -4,7 +4,8 @@ from fastapi_users import schemas class UserRead(schemas.BaseUser[uuid.UUID]): - credit_micros_balance: int + pages_limit: int + pages_used: int display_name: str | None = None avatar_url: str | None = None diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py new file mode 100644 index 000000000..d0eeaf5c6 --- /dev/null +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -0,0 +1,116 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import VisionProvider + + +class VisionLLMConfigBase(BaseModel): + name: str = Field(..., max_length=100) + description: str | None = Field(None, max_length=500) + provider: VisionProvider = Field(...) + custom_provider: str | None = Field(None, max_length=100) + model_name: str = Field(..., max_length=100) + api_key: str = Field(...) + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = Field(default=None) + + +class VisionLLMConfigCreate(VisionLLMConfigBase): + search_space_id: int = Field(...) + + +class VisionLLMConfigUpdate(BaseModel): + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + provider: VisionProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = None + + +class VisionLLMConfigRead(VisionLLMConfigBase): + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class VisionLLMConfigPublic(BaseModel): + id: int + name: str + description: str | None = None + provider: VisionProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + + id: int = Field(...) + name: str + description: str | None = None + provider: str + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + is_global: bool = True + is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." + ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/ai_file_sort_service.py b/surfsense_backend/app/services/ai_file_sort_service.py index 1bf4d325e..2f04131a6 100644 --- a/surfsense_backend/app/services/ai_file_sort_service.py +++ b/surfsense_backend/app/services/ai_file_sort_service.py @@ -156,7 +156,7 @@ async def _resolve_document_text( stmt = ( select(Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) .limit(_MAX_CHUNKS_FOR_CONTEXT) ) result = await session.execute(stmt) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index f98933a65..9bbca8669 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -1,13 +1,13 @@ -"""Resolve and persist Auto model pins per chat thread. +"""Resolve and persist Auto (Fastest) model pins per chat thread. -Auto is represented by ``chat_model_id == 0``. For chat threads we -resolve that virtual mode to one concrete global model exactly once and +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so subsequent turns are stable. Single-writer invariant: this module is the only writer of ``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in -``model_connections_routes`` when a search space's ``chat_model_id`` changes). +``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). Therefore a non-NULL value unambiguously means "this thread has an Auto-resolved pin"; no separate source/policy column is needed. """ @@ -21,35 +21,26 @@ import time from dataclasses import dataclass from uuid import UUID -import redis from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.config import config -from app.db import Connection, Model, NewChatThread -from app.services.model_capabilities import has_capability +from app.db import NewChatThread from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) -AUTO_MODE_ID = 0 -# Stable internal hash namespace for deterministic per-thread selection. -# Do not rename: changing this rebalances Auto's model choice for new pins. -AUTO_PIN_HASH_NAMESPACE = "auto_fastest" +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 -_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:" -_REDIS_TIMEOUT_SECONDS = 0.2 # In-memory runtime cooldown map for configs that recently hard-failed at # provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps # the same unhealthy config from being reselected immediately during repair. _runtime_cooldown_until: dict[int, float] = {} _runtime_cooldown_lock = threading.Lock() -_runtime_cooldown_redis: redis.Redis | None = None -_runtime_cooldown_redis_lock = threading.Lock() # Short-TTL "recently healthy" cache for configs that just passed a runtime # preflight ping. Lets back-to-back turns on the same model skip the probe @@ -70,15 +61,11 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and (cfg.get("provider") or cfg.get("litellm_provider")) + and cfg.get("provider") and cfg.get("api_key") ) -def _has_capability(model: dict | Model, capability: str) -> bool: - return has_capability(model, capability) - - def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: now = time.time() if now_ts is None else now_ts stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] @@ -92,81 +79,6 @@ def _is_runtime_cooled_down(config_id: int) -> bool: return config_id in _runtime_cooldown_until -def _runtime_cooldown_redis_key(config_id: int) -> str: - return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}" - - -def _get_runtime_cooldown_redis() -> redis.Redis: - global _runtime_cooldown_redis - if _runtime_cooldown_redis is None: - with _runtime_cooldown_redis_lock: - if _runtime_cooldown_redis is None: - _runtime_cooldown_redis = redis.from_url( - config.REDIS_APP_URL, - decode_responses=True, - socket_connect_timeout=_REDIS_TIMEOUT_SECONDS, - socket_timeout=_REDIS_TIMEOUT_SECONDS, - ) - return _runtime_cooldown_redis - - -def _mark_shared_runtime_cooldown( - config_id: int, - *, - reason: str, - cooldown_seconds: int, -) -> None: - try: - _get_runtime_cooldown_redis().set( - _runtime_cooldown_redis_key(config_id), - reason, - ex=int(cooldown_seconds), - ) - except Exception: - logger.warning( - "auto_pin_runtime_cooldown_redis_write_failed config_id=%s", - config_id, - exc_info=True, - ) - - -def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]: - unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids)) - if not unique_ids: - return set() - try: - values = _get_runtime_cooldown_redis().mget( - [_runtime_cooldown_redis_key(cid) for cid in unique_ids] - ) - except Exception: - logger.warning( - "auto_pin_runtime_cooldown_redis_read_failed count=%s", - len(unique_ids), - exc_info=True, - ) - return set() - return { - cid for cid, value in zip(unique_ids, values, strict=False) if value is not None - } - - -def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None: - try: - client = _get_runtime_cooldown_redis() - if config_id is not None: - client.delete(_runtime_cooldown_redis_key(config_id)) - return - keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*")) - if keys: - client.delete(*keys) - except Exception: - logger.warning( - "auto_pin_runtime_cooldown_redis_clear_failed config_id=%s", - config_id, - exc_info=True, - ) - - def mark_runtime_cooldown( config_id: int, *, @@ -185,11 +97,6 @@ def mark_runtime_cooldown( with _runtime_cooldown_lock: _runtime_cooldown_until[int(config_id)] = until _prune_runtime_cooldowns() - _mark_shared_runtime_cooldown( - int(config_id), - reason=reason, - cooldown_seconds=int(cooldown_seconds), - ) # A cooled cfg can never be "recently healthy"; drop any stale credit so # the next turn that resolves to it (after cooldown) re-runs preflight. clear_healthy(int(config_id)) @@ -206,9 +113,8 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None: with _runtime_cooldown_lock: if config_id is None: _runtime_cooldown_until.clear() - else: - _runtime_cooldown_until.pop(int(config_id), None) - _clear_shared_runtime_cooldown(config_id) + return + _runtime_cooldown_until.pop(int(config_id), None) def _prune_healthy(now_ts: float | None = None) -> None: @@ -280,20 +186,15 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) -def _global_candidates( - *, - capability: str = "chat", - requires_image_input: bool = False, - shared_cooled_down_ids: set[int] | None = None, -) -> list[dict]: - """Return Auto-eligible global virtual models. +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: + """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers @@ -304,167 +205,30 @@ def _global_candidates( filters out configs whose ``supports_image_input`` resolves to False so a text-only deployment can't be pinned for an image request. """ - connection_by_id = { - int(conn.get("id")): conn - for conn in config.GLOBAL_CONNECTIONS - if conn.get("id") is not None - } - config_by_model_name = { - cfg.get("model_name"): cfg + candidates = [ + cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) - } - candidates: list[dict] = [] - shared_cooled_down_ids = shared_cooled_down_ids or set() - for model in config.GLOBAL_MODELS: - model_id = int(model.get("id", 0)) - if ( - model_id >= 0 - or _is_runtime_cooled_down(model_id) - or model_id in shared_cooled_down_ids - ): - continue - if not _has_capability(model, capability): - continue - cfg = config_by_model_name.get(model.get("model_id")) or {} - if cfg.get("health_gated"): - continue - if requires_image_input and not _has_capability(model, "vision"): - continue - if requires_image_input and cfg and not _cfg_supports_image_input(cfg): - continue - connection = connection_by_id.get(int(model.get("connection_id", 0))) - if not connection: - continue - catalog = model.get("catalog") or {} - candidates.append( - { - "id": model_id, - "model_id": model.get("model_id"), - "source": "global", - "connection": connection, - "supports_chat": model.get("supports_chat"), - "supports_image_input": model.get("supports_image_input"), - "supports_tools": model.get("supports_tools"), - "supports_image_generation": model.get("supports_image_generation"), - "capabilities_override": model.get("capabilities_override") or {}, - "billing_tier": model.get("billing_tier", "free"), - "provider": connection.get("provider"), - "model_name": model.get("model_id"), - "auto_pin_tier": catalog.get("auto_pin_tier") - or cfg.get("auto_pin_tier") - or "A", - "quality_score": catalog.get("quality_score") - or cfg.get("quality_score") - or cfg.get("quality_score_static") - or 50, - } - ) - return sorted(candidates, key=lambda c: int(c.get("id", 0))) - - -async def _db_candidates( - session: AsyncSession, - *, - search_space_id: int, - user_id: str | UUID | None, - capability: str, - requires_image_input: bool = False, -) -> list[dict]: - parsed_user_id = _to_uuid(user_id) - stmt = ( - select(Model) - .options(selectinload(Model.connection)) - .join(Connection, Model.connection_id == Connection.id) - .where(Model.enabled.is_(True), Connection.enabled.is_(True)) - ) - result = await session.execute(stmt) - models = result.scalars().all() - shared_cooled_down_ids = _shared_runtime_cooled_down_ids( - [int(model.id) for model in models] - ) - candidates: list[dict] = [] - for model in models: - conn = model.connection - if not conn: - continue - if conn.search_space_id is not None and conn.search_space_id != search_space_id: - continue - if ( - conn.user_id is not None - and parsed_user_id is not None - and conn.user_id != parsed_user_id - ): - continue - if conn.user_id is not None and parsed_user_id is None: - continue - if not _has_capability(model, capability): - continue - if requires_image_input and not _has_capability(model, "vision"): - continue - model_id = int(model.id) - if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids: - continue - catalog = model.catalog or {} - candidates.append( - { - "id": model_id, - "model_id": model.model_id, - "source": "db", - "connection": conn, - "supports_chat": model.supports_chat, - "supports_image_input": model.supports_image_input, - "supports_tools": model.supports_tools, - "supports_image_generation": model.supports_image_generation, - "capabilities_override": model.capabilities_override or {}, - "billing_tier": "byok", - "provider": conn.provider, - "model_name": model.model_id, - "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", - "quality_score": catalog.get("quality_score") or 75, - } - ) - return sorted(candidates, key=lambda c: int(c.get("id", 0))) - - -async def auto_model_candidates( - session: AsyncSession, - *, - search_space_id: int, - user_id: str | UUID | None, - capability: str, - requires_image_input: bool = False, - exclude_model_ids: set[int] | None = None, -) -> list[dict]: - excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} - global_ids = [ - int(model.get("id", 0)) - for model in config.GLOBAL_MODELS - if int(model.get("id", 0)) < 0 + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) ] - shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids) - db_candidates = await _db_candidates( - session, - search_space_id=search_space_id, - user_id=user_id, - capability=capability, - requires_image_input=requires_image_input, - ) - candidates = [ - *_global_candidates( - capability=capability, - requires_image_input=requires_image_input, - shared_cooled_down_ids=shared_global_cooled_down_ids, - ), - *db_candidates, - ] - return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() +def _is_preferred_premium_auto_config(cfg: dict) -> bool: + """Return True for the operator-preferred premium Auto model.""" + return ( + _tier_of(cfg) == "premium" + and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + and str(cfg.get("model_name", "")).lower() == "gpt-5.4" + ) + + def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: """Pick a config with quality-first ranking + deterministic spread. @@ -482,16 +246,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: pool = tier_a if tier_a else eligible pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) top_k = pool[:_QUALITY_TOP_K] - digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest() + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() idx = int.from_bytes(digest[:8], "big") % len(top_k) return top_k[idx], len(top_k) -def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict: - selected, _ = _select_pin(candidates, seed_id) - return selected - - def _to_uuid(user_id: str | UUID | None) -> UUID | None: if user_id is None: return None @@ -509,7 +268,7 @@ async def _is_premium_eligible( parsed = _to_uuid(user_id) if parsed is None: return False - usage = await TokenQuotaService.credit_get_usage(session, parsed) + usage = await TokenQuotaService.premium_get_usage(session, parsed) return bool(usage.allowed) @@ -524,7 +283,7 @@ async def resolve_or_get_pinned_llm_config_id( exclude_config_ids: set[int] | None = None, requires_image_input: bool = False, ) -> AutoPinResolution: - """Resolve Auto to one concrete config id and persist the pin. + """Resolve Auto (Fastest) to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. @@ -556,7 +315,7 @@ async def resolve_or_get_pinned_llm_config_id( ) # Explicit model selected: clear any stale pin. - if selected_llm_config_id != AUTO_MODE_ID: + if selected_llm_config_id != AUTO_FASTEST_ID: if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None await session.commit() @@ -567,21 +326,20 @@ async def resolve_or_get_pinned_llm_config_id( ) excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=user_id, - capability="chat", - requires_image_input=requires_image_input, - exclude_model_ids=excluded_ids, - ) + candidates = [ + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids + ] if not candidates: if requires_image_input: # Distinguish the "no vision-capable cfg" case from generic # "no usable cfg" so the streaming task can map this to the # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. - raise ValueError("No vision-capable LLM models are available for Auto mode") - raise ValueError("No usable LLM models are available for Auto mode") + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) + raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent @@ -621,13 +379,24 @@ async def resolve_or_get_pinned_llm_config_id( # log that explicitly so operators can correlate the re-pin with # the user's image attachment instead of suspecting a cooldown. if requires_image_input: - logger.info( - "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " - "previous_config_id=%s", - thread_id, - search_space_id, - pinned_id, - ) + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, @@ -638,10 +407,12 @@ async def resolve_or_get_pinned_llm_config_id( premium_eligible = ( False if force_repin_free else await _is_premium_eligible(session, user_id) ) - byok_candidates = [c for c in candidates if _tier_of(c) == "byok"] if premium_eligible: premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] - eligible = premium_candidates or byok_candidates + preferred_premium = [ + c for c in premium_candidates if _is_preferred_premium_auto_config(c) + ] + eligible = preferred_premium or premium_candidates else: eligible = [c for c in candidates if _tier_of(c) != "premium"] diff --git a/surfsense_backend/app/services/auto_reload_service.py b/surfsense_backend/app/services/auto_reload_service.py deleted file mode 100644 index 9f5114a56..000000000 --- a/surfsense_backend/app/services/auto_reload_service.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Debit-triggered credit auto-reload. - -``maybe_trigger_auto_reload`` is a cheap, best-effort pre-filter invoked after -every credit debit (ETL ``charge_credits`` and premium ``credit_finalize``). -When the wallet drops below the user's configured threshold it enqueues the -Celery task that performs the authoritative re-check and the off-session Stripe -charge. All real safety (row lock, cooldown, Stripe idempotency) lives in the -task — this function only avoids enqueuing work that obviously isn't needed. - -Everything here is gated behind ``config.AUTO_RELOAD_ENABLED``; when the flag is -off this module is inert. -""" - -from __future__ import annotations - -import logging -from datetime import UTC, datetime, timedelta - -from sqlalchemy import select - -from app.config import config - -logger = logging.getLogger(__name__) - - -async def maybe_trigger_auto_reload(user_id: str) -> None: - """Enqueue an auto-reload charge if the user's balance fell below threshold. - - Best-effort: any failure is swallowed by the caller. Opens its own - short-lived session so it never interferes with the caller's transaction - (it always runs after the caller has already committed the debit). - """ - if not config.AUTO_RELOAD_ENABLED: - return - - from app.db import CreditPurchase, CreditPurchaseStatus, User, async_session_maker - - async with async_session_maker() as session: - user = ( - (await session.execute(select(User).where(User.id == user_id))) - .unique() - .scalar_one_or_none() - ) - if user is None or not user.auto_reload_enabled: - return - - if not (user.stripe_customer_id and user.auto_reload_payment_method_id): - return - - threshold = user.auto_reload_threshold_micros - amount = user.auto_reload_amount_micros - if not threshold or not amount: - return - - available = user.credit_micros_balance - user.credit_micros_reserved - if available >= threshold: - return - - # Cheap cooldown pre-check: skip if a recent auto-reload purchase exists - # or a recent attempt failed (avoids hammering a declined card). - cutoff = datetime.now(UTC) - timedelta( - minutes=max(config.AUTO_RELOAD_COOLDOWN_MINUTES, 0) - ) - if user.auto_reload_failed_at and user.auto_reload_failed_at >= cutoff: - return - recent = ( - await session.execute( - select(CreditPurchase.id) - .where( - CreditPurchase.user_id == user.id, - CreditPurchase.source == "auto_reload", - CreditPurchase.created_at >= cutoff, - CreditPurchase.status.in_( - [ - CreditPurchaseStatus.PENDING, - CreditPurchaseStatus.COMPLETED, - ] - ), - ) - .limit(1) - ) - ).first() - if recent is not None: - return - - # Enqueue outside the session. The task re-checks everything with a row - # lock before charging, so a benign race here only costs a no-op task run. - try: - from app.tasks.celery_tasks.auto_reload_task import ( - auto_reload_credits_task, - ) - - auto_reload_credits_task.delay(str(user_id)) - except Exception: - logger.warning( - "Failed to enqueue auto_reload_credits task for user %s", - user_id, - exc_info=True, - ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 15a3c3e55..92ccd6a78 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -69,8 +69,8 @@ BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] class QuotaInsufficientError(Exception): - """Raised when ``TokenQuotaService.credit_reserve`` denies a billable - call because the user has exhausted their credit wallet. + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. The route handler should catch this and return HTTP 402 Payment Required (or the equivalent for the surface area). Outside of the HTTP @@ -83,15 +83,17 @@ class QuotaInsufficientError(Exception): self, *, usage_type: str, - balance_micros: int, + used_micros: int, + limit_micros: int, remaining_micros: int, ) -> None: self.usage_type = usage_type - self.balance_micros = balance_micros + self.used_micros = used_micros + self.limit_micros = limit_micros self.remaining_micros = remaining_micros super().__init__( - f"Credit exhausted for {usage_type}: " - f"balance={balance_micros} remaining={remaining_micros} (micro-USD)" + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" ) @@ -265,7 +267,7 @@ async def billable_call( ``TokenTrackingCallback`` populates the accumulator automatically. Raises: - QuotaInsufficientError: when premium and ``credit_reserve`` denies. + QuotaInsufficientError: when premium and ``premium_reserve`` denies. """ is_premium = billing_tier == "premium" session_factory = billable_session_factory or shielded_async_session @@ -308,7 +310,7 @@ async def billable_call( request_id = str(uuid4()) async with session_factory() as quota_session: - reserve_result = await TokenQuotaService.credit_reserve( + reserve_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=user_id, request_id=request_id, @@ -318,16 +320,18 @@ async def billable_call( if not reserve_result.allowed: logger.info( "[billable_call] reserve DENIED user=%s usage_type=%s " - "reserve=%d balance=%d remaining=%d", + "reserve=%d used=%d limit=%d remaining=%d", user_id, usage_type, reserve_micros, - reserve_result.balance, + reserve_result.used, + reserve_result.limit, reserve_result.remaining, ) raise QuotaInsufficientError( usage_type=usage_type, - balance_micros=reserve_result.balance, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, remaining_micros=reserve_result.remaining, ) @@ -348,14 +352,14 @@ async def billable_call( # BaseException so cancellation also releases. try: async with session_factory() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, ) except Exception: logger.exception( - "[billable_call] credit_release failed for user=%s " + "[billable_call] premium_release failed for user=%s " "reserve_micros=%d (reservation will be GC'd by quota " "reconciliation if/when implemented)", user_id, @@ -376,7 +380,7 @@ async def billable_call( thread_id, ) async with session_factory() as quota_session: - final_result = await TokenQuotaService.credit_finalize( + final_result = await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=user_id, request_id=request_id, @@ -385,25 +389,26 @@ async def billable_call( ) logger.info( "[billable_call] finalize user=%s usage_type=%s actual=%d " - "reserved=%d → balance=%d (remaining=%d)", + "reserved=%d → used=%d/%d (remaining=%d)", user_id, usage_type, actual_micros, reserve_micros, - final_result.balance, + final_result.used, + final_result.limit, final_result.remaining, ) except Exception as finalize_exc: # Last-ditch: if finalize itself fails, we must at least release # so the reservation doesn't leak. logger.exception( - "[billable_call] credit_finalize failed for user=%s; " + "[billable_call] premium_finalize failed for user=%s; " "attempting release", user_id, ) try: async with session_factory() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, @@ -445,22 +450,22 @@ async def _resolve_agent_billing_for_search_space( thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space - chat model. + agent LLM. Used by Celery tasks (podcast generation, video presentation) to bill the - search-space owner's premium credit pool when the chat model is premium. + search-space owner's premium credit pool when the agent LLM is premium. - Resolution rules mirror the chat model role resolver: + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: - - Search space not found / no ``chat_model_id``: raise ``ValueError``. - - **Auto mode** (``id == AUTO_MODE_ID == 0``): + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present so the same model bills for chat + downstream podcast/video. If the user is not premium-eligible, the pin service auto-restricts to free deployments — denial only happens later in - ``billable_call.credit_reserve`` if the pin really is premium and + ``billable_call.premium_reserve`` if the pin really is premium and credit ran out mid-flow. * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat for any future direct-API path; today both Celery tasks always pass @@ -469,8 +474,9 @@ async def _resolve_agent_billing_for_search_space( (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - - **Positive id** (user BYOK ``Model``): always free; ``base_model`` from - the model catalog override or the upstream ``model_id``. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to @@ -479,9 +485,8 @@ async def _resolve_agent_billing_for_search_space( ``billable_calls.py``'s module load path. """ from sqlalchemy import select - from sqlalchemy.orm import selectinload - from app.db import Model, SearchSpace + from app.db import NewLLMConfig, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -490,20 +495,20 @@ async def _resolve_agent_billing_for_search_space( if search_space is None: raise ValueError(f"Search space {search_space_id} not found") - chat_model_id = search_space.chat_model_id - if chat_model_id is None: + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: raise ValueError( - f"Search space {search_space_id} has no chat_model_id configured" + f"Search space {search_space_id} has no agent_llm_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( - AUTO_MODE_ID, + AUTO_FASTEST_ID, resolve_or_get_pinned_llm_config_id, ) - if chat_model_id == AUTO_MODE_ID: + if agent_llm_id == AUTO_FASTEST_ID: if thread_id is None: return owner_user_id, "free", "auto" try: @@ -512,7 +517,7 @@ async def _resolve_agent_billing_for_search_space( thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), - selected_llm_config_id=AUTO_MODE_ID, + selected_llm_config_id=AUTO_FASTEST_ID, ) except ValueError: logger.warning( @@ -523,35 +528,28 @@ async def _resolve_agent_billing_for_search_space( exc_info=True, ) return owner_user_id, "free", "auto" - chat_model_id = resolution.resolved_llm_config_id + agent_llm_id = resolution.resolved_llm_config_id - if chat_model_id < 0: + if agent_llm_id < 0: from app.services.llm_service import get_global_llm_config - cfg = get_global_llm_config(chat_model_id) or {} + cfg = get_global_llm_config(agent_llm_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model - model_result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == chat_model_id, Model.enabled.is_(True)) - ) - model = model_result.scalars().first() - base_model = "" - if ( - model is not None - and model.connection is not None - and model.connection.enabled - and ( - model.connection.search_space_id in (None, search_space_id) - and model.connection.user_id in (None, owner_user_id) + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, ) - ): - catalog = model.catalog or {} - base_model = catalog.get("base_model") or model.model_id or "" + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" return owner_user_id, "free", base_model diff --git a/surfsense_backend/app/services/export_service.py b/surfsense_backend/app/services/export_service.py index 9e6869fe1..97f952223 100644 --- a/surfsense_backend/app/services/export_service.py +++ b/surfsense_backend/app/services/export_service.py @@ -62,7 +62,7 @@ async def _get_document_markdown( chunk_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunks = chunk_result.scalars().all() if chunks: diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py deleted file mode 100644 index 1bcc99215..000000000 --- a/surfsense_backend/app/services/global_model_catalog.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Materialize server-owned GLOBAL YAML configs as virtual connections/models.""" - -from __future__ import annotations - -from typing import Any - -from app.services.model_resolver import native_connection_from_config - - -def _base_model(config: dict[str, Any]) -> str | None: - litellm_params = config.get("litellm_params") or {} - if isinstance(litellm_params, dict): - return litellm_params.get("base_model") - return None - - -def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: - # Deliberately includes api_key because two operator-owned credentials for - # the same provider/base can have different quota/rate limits upstream. - return ( - conn.get("provider"), - conn.get("base_url"), - conn.get("api_key"), - _freeze(conn.get("extra") or {}), - ) - - -def _freeze(value: Any) -> Any: - if isinstance(value, dict): - return tuple(sorted((key, _freeze(val)) for key, val in value.items())) - if isinstance(value, list): - return tuple(_freeze(item) for item in value) - return value - - -def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: - return { - "billing_tier": config.get("billing_tier", "free"), - "quota_reserve_tokens": config.get("quota_reserve_tokens"), - "rpm": config.get("rpm"), - "tpm": config.get("tpm"), - "anonymous_enabled": config.get("anonymous_enabled", False), - "seo_enabled": config.get("seo_enabled", False), - "seo_slug": config.get("seo_slug"), - "input_cost_per_token": (config.get("litellm_params") or {}).get( - "input_cost_per_token" - ) - if isinstance(config.get("litellm_params"), dict) - else None, - "output_cost_per_token": (config.get("litellm_params") or {}).get( - "output_cost_per_token" - ) - if isinstance(config.get("litellm_params"), dict) - else None, - "is_planner": config.get("is_planner", False), - "base_model": _base_model(config), - "router_pool_eligible": config.get("router_pool_eligible", True), - } - - -def materialize_global_model_catalog( - *, - chat_configs: list[dict[str, Any]], - image_configs: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - connections: list[dict[str, Any]] = [] - models: list[dict[str, Any]] = [] - connection_id_by_key: dict[tuple[Any, ...], int] = {} - next_connection_id = -1 - - def add_config(config: dict[str, Any], role: str) -> None: - nonlocal next_connection_id - if not config.get("id") or not config.get("model_name"): - return - conn = native_connection_from_config(config) - conn["scope"] = "GLOBAL" - conn["enabled"] = True - key = _connection_key(conn) - connection_id = connection_id_by_key.get(key) - if connection_id is None: - connection_id = next_connection_id - next_connection_id -= 1 - connection_id_by_key[key] = connection_id - connections.append( - { - "id": connection_id, - **conn, - } - ) - - model_id = int(config["id"]) - models.append( - { - "id": model_id, - "connection_id": connection_id, - "model_id": config["model_name"], - "display_name": config.get("name") or config["model_name"], - "source": "MANUAL", - "supports_chat": role == "chat", - "max_input_tokens": config.get("max_input_tokens"), - "supports_image_input": ( - role == "chat" and bool(config.get("supports_image_input")) - ), - "supports_tools": bool(config.get("supports_tools", False)), - "supports_image_generation": role == "image_gen", - "capabilities_override": {}, - "enabled": True, - "billing_tier": config.get("billing_tier", "free"), - "catalog": _catalog_metadata(config), - "role": role, - } - ) - - for cfg in chat_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "chat") - for cfg in image_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "image_gen") - - # Each virtual connection is server-only. Callers that serialize these - # must strip api_key before returning data to clients. - return connections, models - - -__all__ = ["materialize_global_model_catalog"] diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index 241b3bc53..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,13 +20,28 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.provider_api_base import resolve_api_base logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing IMAGE_GEN_AUTO_MODE_ID = 0 +# Provider mapping for LiteLLM model string construction. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +IMAGE_GEN_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + class ImageGenRouterService: """ @@ -138,11 +153,38 @@ class ImageGenRouterService: if not config.get("model_name") or not config.get("api_key"): return None - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], + # Build model string + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + else: + provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Resolve ``api_base`` so deployments don't silently inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), ) - litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} + if api_base: + litellm_params["api_base"] = api_base + + # Add api_version (required for Azure) + if config.get("api_version"): + litellm_params["api_version"] = config["api_version"] + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) # All configs use same alias "auto" for unified routing deployment: dict[str, Any] = { diff --git a/surfsense_backend/app/services/llm_error_adapter.py b/surfsense_backend/app/services/llm_error_adapter.py deleted file mode 100644 index 8d451ee96..000000000 --- a/surfsense_backend/app/services/llm_error_adapter.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Normalize provider/LLM exceptions into low-cardinality product categories.""" - -from __future__ import annotations - -import json -from dataclasses import dataclass -from enum import StrEnum -from typing import Any - - -class LLMErrorCategory(StrEnum): - RATE_LIMITED = "rate_limited" - TIMEOUT = "timeout" - PROVIDER_UNAVAILABLE = "provider_unavailable" - BAD_GATEWAY = "bad_gateway" - CONNECTION_FAILED = "connection_failed" - AUTH_FAILED = "auth_failed" - PERMISSION_DENIED = "permission_denied" - MODEL_NOT_FOUND = "model_not_found" - BAD_REQUEST = "bad_request" - CONTEXT_LIMIT = "context_limit" - RESPONSE_INVALID = "response_invalid" - SERVER_ERROR = "server_error" - UNKNOWN = "unknown" - - -@dataclass(frozen=True) -class LLMErrorAdaptation: - category: LLMErrorCategory - retryable: bool - user_message: str - provider_status_code: int | None = None - provider_error_type: str | None = None - - -_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = { - LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.", - LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.", - LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.", - LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.", - LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.", - LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.", - LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.", - LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.", - LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.", - LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.", - LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.", - LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.", - LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.", -} - -_RETRYABLE_CATEGORIES = { - LLMErrorCategory.RATE_LIMITED, - LLMErrorCategory.TIMEOUT, - LLMErrorCategory.PROVIDER_UNAVAILABLE, - LLMErrorCategory.BAD_GATEWAY, - LLMErrorCategory.CONNECTION_FAILED, - LLMErrorCategory.SERVER_ERROR, -} - -_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] = ( - ( - LLMErrorCategory.RATE_LIMITED, - ("RateLimitError", "TooManyRequests", "TooManyRequestsError"), - ), - (LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")), - ( - LLMErrorCategory.PROVIDER_UNAVAILABLE, - ("ServiceUnavailableError", "ServiceUnavailable"), - ), - ( - LLMErrorCategory.BAD_GATEWAY, - ("BadGatewayError", "GatewayTimeoutError"), - ), - ( - LLMErrorCategory.CONNECTION_FAILED, - ("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"), - ), - ( - LLMErrorCategory.AUTH_FAILED, - ("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"), - ), - (LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")), - (LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")), - ( - LLMErrorCategory.CONTEXT_LIMIT, - ("ContextWindowExceeded", "ContextOverflow", "ContextLimit"), - ), - ( - LLMErrorCategory.RESPONSE_INVALID, - ("APIResponseValidationError", "ResponseValidationError"), - ), - ( - LLMErrorCategory.BAD_REQUEST, - ("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"), - ), - (LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)), -) - - -def _parse_error_payload(message: str) -> dict[str, Any] | None: - candidates = [message] - first_brace_idx = message.find("{") - if first_brace_idx >= 0: - candidates.append(message[first_brace_idx:]) - - for candidate in candidates: - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - return parsed - except Exception: - continue - return None - - -def _class_names(exc: BaseException) -> tuple[str, ...]: - return tuple(cls.__name__ for cls in type(exc).__mro__) - - -def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None: - names = _class_names(exc) - for category, hints in _CLASS_NAME_MAP: - if any(any(hint in name for hint in hints) for name in names): - return category - return None - - -def _extract_provider_status_code(parsed: dict[str, Any] | None) -> int | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("code"), parsed.get("status")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.extend([nested.get("code"), nested.get("status")]) - for value in candidates: - try: - if value is None: - continue - return int(value) - except Exception: - continue - return None - - -def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("type")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.append(nested.get("type")) - for value in candidates: - if isinstance(value, str) and value: - return value - return None - - -def _category_from_provider_payload( - status_code: int | None, - provider_error_type: str | None, -) -> LLMErrorCategory | None: - if status_code == 429: - return LLMErrorCategory.RATE_LIMITED - if status_code == 401: - return LLMErrorCategory.AUTH_FAILED - if status_code == 403: - return LLMErrorCategory.PERMISSION_DENIED - if status_code == 404: - return LLMErrorCategory.MODEL_NOT_FOUND - if status_code in (400, 422): - return LLMErrorCategory.BAD_REQUEST - if status_code in (502, 504): - return LLMErrorCategory.BAD_GATEWAY - if status_code == 503: - return LLMErrorCategory.PROVIDER_UNAVAILABLE - if status_code is not None and status_code >= 500: - return LLMErrorCategory.SERVER_ERROR - - normalized_type = (provider_error_type or "").lower() - if normalized_type == "rate_limit_error": - return LLMErrorCategory.RATE_LIMITED - if normalized_type in { - "authentication_error", - "invalid_api_key", - "invalid_api_key_error", - }: - return LLMErrorCategory.AUTH_FAILED - if normalized_type in {"permission_denied", "forbidden"}: - return LLMErrorCategory.PERMISSION_DENIED - if normalized_type in {"not_found_error", "model_not_found"}: - return LLMErrorCategory.MODEL_NOT_FOUND - if normalized_type in {"context_length_exceeded", "context_window_exceeded"}: - return LLMErrorCategory.CONTEXT_LIMIT - return None - - -def _category_from_message(raw: str) -> LLMErrorCategory | None: - lowered = raw.lower() - if any( - hint in lowered - for hint in ("rate limit", "rate-limited", "temporarily rate-limited") - ): - return LLMErrorCategory.RATE_LIMITED - if any( - hint in lowered - for hint in ( - "invalid api key", - "invalid_api_key", - "authentication", - "unauthorized", - "user not found", - "api key is expired", - "expired api key", - ) - ): - return LLMErrorCategory.AUTH_FAILED - if "forbidden" in lowered or "permission denied" in lowered: - return LLMErrorCategory.PERMISSION_DENIED - if "model not found" in lowered: - return LLMErrorCategory.MODEL_NOT_FOUND - if any( - hint in lowered - for hint in ( - "context length", - "context window", - "maximum context", - "too many tokens", - ) - ): - return LLMErrorCategory.CONTEXT_LIMIT - return None - - -def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation: - raw = str(exc) - parsed = _parse_error_payload(raw) - status_code = _extract_provider_status_code(parsed) - provider_error_type = _extract_provider_error_type(parsed) - - category = ( - _category_from_provider_payload(status_code, provider_error_type) - or _category_from_message(raw) - or _category_from_class_name(exc) - or LLMErrorCategory.UNKNOWN - ) - return LLMErrorAdaptation( - category=category, - retryable=category in _RETRYABLE_CATEGORIES, - user_message=_CATEGORY_MESSAGES[category], - provider_status_code=status_code, - provider_error_type=provider_error_type, - ) - - -def llm_error_message(exc: BaseException) -> str: - return adapt_llm_exception(exc).user_message diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 3affdcce7..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,7 +30,6 @@ from litellm.exceptions import ( ) from pydantic import Field -from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -97,6 +96,53 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 +# Provider mapping for LiteLLM model string construction +PROVIDER_MAP = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", # Legacy support + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. +from app.services.provider_api_base import ( # noqa: E402 + resolve_api_base, +) + class LLMRouterService: """ @@ -374,11 +420,38 @@ class LLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], + # Build model string + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" + else: + provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Resolve ``api_base``. Config value wins; otherwise apply a + # provider-aware default so the deployment does not silently + # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route + # requests to the wrong endpoint. See ``provider_api_base`` + # docstring for the motivating bug (OpenRouter models 404-ing + # against an Azure endpoint). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), ) - litellm_params = {"model": model_string, **resolved_kwargs} + if api_base: + litellm_params["api_base"] = api_base + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) # Extract rate limits if provided deployment = { diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index e535d0150..7061a826f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -6,21 +6,17 @@ from langchain_core.messages import HumanMessage from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload from app.config import config -from app.db import Model, SearchSpace -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) +from app.db import NewLLMConfig, SearchSpace from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, + LLMRouterService, + get_auto_mode_llm, is_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -70,29 +66,6 @@ def _is_interactive_auth_provider( return False -def _legacy_config_connection( - *, - provider: str, - model_name: str, - api_key: str | None, - api_base: str | None, - custom_provider: str | None = None, - litellm_params: dict | None = None, - api_version: str | None = None, -) -> tuple[str, dict]: - cfg = { - "provider": provider.lower(), - "model_name": model_name, - "api_key": api_key, - "api_base": api_base, - "custom_provider": custom_provider, - "api_version": api_version, - "litellm_params": litellm_params or {}, - } - conn = native_connection_from_config(cfg) - return to_litellm(conn, model_name) - - class LLMRole: AGENT = "agent" # For agent/chat operations @@ -100,16 +73,26 @@ class LLMRole: def get_global_llm_config(llm_config_id: int) -> dict | None: """ Get a global LLM configuration by ID. - Global configs have negative IDs. Auto mode (ID 0) is resolved through the - model-candidate pipeline, not this legacy config lookup. + Global configs have negative IDs. ID 0 is reserved for Auto mode. Args: - llm_config_id: The ID of the global config (must be negative) + llm_config_id: The ID of the global config (should be negative or 0 for Auto) Returns: dict: Global config dictionary or None if not found """ - if llm_config_id >= 0: + # Auto mode (ID 0) is handled separately via the router + if llm_config_id == AUTO_MODE_ID: + return { + "id": AUTO_MODE_ID, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + + if llm_config_id > 0: return None for cfg in config.GLOBAL_LLM_CONFIGS: @@ -119,55 +102,6 @@ def get_global_llm_config(llm_config_id: int) -> dict | None: return None -def get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) - - -def get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) - - -def _has_capability(model: dict | Model, capability: str) -> bool: - return has_capability(model, capability) - - -def _chat_litellm_from_resolved( - *, - conn: dict | object, - model_id: str, - disable_streaming: bool = False, -) -> tuple[str, dict]: - model_string, resolved_kwargs = to_litellm(conn, model_id) - litellm_kwargs = {"model": model_string, **resolved_kwargs} - if disable_streaming: - litellm_kwargs["disable_streaming"] = True - return model_string, litellm_kwargs - - -async def _get_db_model( - session: AsyncSession, - model_id: int, - search_space: SearchSpace, -) -> Model | None: - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id, Model.enabled.is_(True)) - ) - model = result.scalars().first() - if not model or not model.connection or not model.connection.enabled: - return None - conn = model.connection - if conn.search_space_id and conn.search_space_id != search_space.id: - return None - if conn.user_id and conn.user_id != search_space.user_id: - return None - return model - - async def validate_llm_config( provider: str, model_name: str, @@ -212,15 +146,62 @@ async def validate_llm_config( return False, msg try: - model_string, resolved_kwargs = _legacy_config_connection( - provider=provider, - model_name=model_name, - api_key=api_key, - api_base=api_base, - custom_provider=custom_provider, - litellm_params=litellm_params, - ) - litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30} + # Build the model string for litellm + if custom_provider: + model_string = f"{custom_provider}/{model_name}" + else: + # Map provider enum to litellm format + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility) + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + # Chinese LLM providers + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", # GLM needs special handling + "MINIMAX": "openai", + "GITHUB_MODELS": "github", + } + provider_prefix = provider_map.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{model_name}" + + # Create ChatLiteLLM instance + litellm_kwargs = { + "model": model_string, + "api_key": api_key, + "timeout": 30, # Set a timeout for validation + } + + # Add optional parameters + if api_base: + litellm_kwargs["api_base"] = api_base + + # Add any additional litellm parameters + if litellm_params: + litellm_kwargs.update(litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -302,9 +283,9 @@ async def get_search_space_llm_instance( logger.error(f"Search space {search_space_id} not found") return None - # Get the appropriate model binding ID based on role + # Get the appropriate LLM config ID based on role if role == LLMRole.AGENT: - llm_config_id = search_space.chat_model_id + llm_config_id = search_space.agent_llm_id else: logger.error(f"Invalid LLM role: {role}") return None @@ -313,42 +294,88 @@ async def get_search_space_llm_instance( logger.error(f"No {role} LLM configured for search space {search_space_id}") return None - # Auto mode resolves to one concrete global or BYOK model from the - # unified model-connections catalog. + # Check for Auto mode (ID 0) - use router for load balancing if is_auto_mode(llm_config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=search_space.user_id, - capability="chat", - ) - if not candidates: - logger.error("No chat-capable models available for Auto mode") - return None - llm_config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] - ) - - # Check if this is a global virtual model (negative ID) - if llm_config_id < 0: - global_model = get_global_model(llm_config_id) - if not global_model or not _has_capability(global_model, "chat"): - logger.error(f"Global chat model {llm_config_id} not found") - return None - global_connection = get_global_connection(global_model["connection_id"]) - if not global_connection: + if not LLMRouterService.is_initialized(): logger.error( - "Global connection %s not found for model %s", - global_model["connection_id"], - llm_config_id, + "Auto mode requested but LLM Router not initialized. " + "Ensure global_llm_config.yaml exists with valid configs." ) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=global_model["model_id"], - disable_streaming=disable_streaming, - ) + try: + logger.debug( + f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" + ) + return get_auto_mode_llm(streaming=not disable_streaming) + except Exception as e: + logger.error(f"Failed to create ChatLiteLLMRouter: {e}") + return None + + # Check if this is a global config (negative ID) + if llm_config_id < 0: + global_config = get_global_llm_config(llm_config_id) + if not global_config: + logger.error(f"Global LLM config {llm_config_id} not found") + return None + + # Build model string for global config + if global_config.get("custom_provider"): + model_string = ( + f"{global_config['custom_provider']}/{global_config['model_name']}" + ) + else: + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "MINIMAX": "openai", + } + provider_prefix = provider_map.get( + global_config["provider"], global_config["provider"].lower() + ) + model_string = f"{provider_prefix}/{global_config['model_name']}" + + # Create ChatLiteLLM instance from global config + litellm_kwargs = { + "model": model_string, + "api_key": global_config["api_key"], + } + + if global_config.get("api_base"): + litellm_kwargs["api_base"] = global_config["api_base"] + + if global_config.get("litellm_params"): + litellm_kwargs.update(global_config["litellm_params"]) + + if disable_streaming: + litellm_kwargs["disable_streaming"] = True from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -356,18 +383,80 @@ async def get_search_space_llm_instance( return SanitizedChatLiteLLM(**litellm_kwargs) - model = await _get_db_model(session, llm_config_id, search_space) - if not model or not _has_capability(model, "chat"): + # Get the LLM configuration from database (NewLLMConfig) + result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == llm_config_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + llm_config = result.scalars().first() + + if not llm_config: logger.error( - f"Chat model {llm_config_id} not found in search space {search_space_id}" + f"LLM config {llm_config_id} not found in search space {search_space_id}" ) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=model.connection, - model_id=model.model_id, - disable_streaming=disable_streaming, - ) + # Build the model string for litellm + if llm_config.custom_provider: + model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" + else: + # Map provider enum to litellm format + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "MINIMAX": "openai", + "GITHUB_MODELS": "github", + } + provider_prefix = provider_map.get( + llm_config.provider.value, llm_config.provider.value.lower() + ) + model_string = f"{provider_prefix}/{llm_config.model_name}" + + # Create ChatLiteLLM instance + litellm_kwargs = { + "model": model_string, + "api_key": llm_config.api_key, + } + + # Add optional parameters + if llm_config.api_base: + litellm_kwargs["api_base"] = llm_config.api_base + + # Add any additional litellm parameters + if llm_config.litellm_params: + litellm_kwargs.update(llm_config.litellm_params) + + if disable_streaming: + litellm_kwargs["disable_streaming"] = True from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -385,7 +474,7 @@ async def get_search_space_llm_instance( async def get_agent_llm( session: AsyncSession, search_space_id: int, disable_streaming: bool = False ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Get the search space's chat model instance.""" + """Get the search space's agent LLM instance for chat operations.""" return await get_search_space_llm_instance( session, search_space_id, @@ -399,17 +488,24 @@ async def get_vision_llm( ) -> ChatLiteLLM | ChatLiteLLMRouter | None: """Get the search space's vision LLM instance for screenshot analysis. - Resolves from the new connection/model role bindings: - - Auto mode (ID 0): unified global/BYOK model candidate selection - - Global (negative ID): virtual GLOBAL models from YAML - - DB (positive ID): Model + Connection tables + Resolves from the dedicated VisionLLMConfig system: + - Auto mode (ID 0): VisionLLMRouterService + - Global (negative ID): YAML configs + - DB (positive ID): VisionLLMConfig table Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` so each ``ainvoke`` debits the search-space owner's premium credit pool. User-owned BYOK configs and free global configs are returned unwrapped — they don't consume premium credit (issue M). """ + from app.db import VisionLLMConfig from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + from app.services.vision_llm_router_service import ( + VISION_PROVIDER_MAP, + VisionLLMRouterService, + get_global_vision_llm_config, + is_vision_auto_mode, + ) try: result = await session.execute( @@ -420,78 +516,64 @@ async def get_vision_llm( logger.error(f"Search space {search_space_id} not found") return None - owner_user_id = search_space.user_id - - # Prefer the selected chat model when it is vision-capable. - chat_model_id = search_space.chat_model_id - if chat_model_id and chat_model_id != AUTO_MODE_ID: - if chat_model_id < 0: - chat_model = get_global_model(chat_model_id) - if chat_model and _has_capability(chat_model, "vision"): - global_connection = get_global_connection( - chat_model["connection_id"] - ) - if global_connection: - model_string, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=chat_model["model_id"], - ) - from app.agents.chat.runtime.llm_config import ( - SanitizedChatLiteLLM, - ) - - return SanitizedChatLiteLLM(**litellm_kwargs) - else: - chat_model = await _get_db_model(session, chat_model_id, search_space) - if chat_model and _has_capability(chat_model, "vision"): - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=chat_model.connection, - model_id=chat_model.model_id, - ) - from app.agents.chat.runtime.llm_config import ( - SanitizedChatLiteLLM, - ) - - return SanitizedChatLiteLLM(**litellm_kwargs) - - config_id = search_space.vision_model_id + config_id = search_space.vision_llm_config_id if config_id is None: logger.error(f"No vision LLM configured for search space {search_space_id}") return None - if config_id == AUTO_MODE_ID: - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=owner_user_id, - capability="vision", - ) - if not candidates: - logger.error("No vision-capable models available for Auto mode") - return None - config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] - ) + owner_user_id = search_space.user_id - if config_id < 0: - global_model = get_global_model(config_id) - if not global_model or not _has_capability(global_model, "vision"): - logger.error(f"Global vision model {config_id} not found") - return None - - global_connection = get_global_connection(global_model["connection_id"]) - if not global_connection: + if is_vision_auto_mode(config_id): + if not VisionLLMRouterService.is_initialized(): logger.error( - "Global connection %s not found for model %s", - global_model["connection_id"], - config_id, + "Vision Auto mode requested but Vision LLM Router not initialized" ) return None + try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. + return ChatLiteLLMRouter( + router=VisionLLMRouterService.get_router(), + streaming=True, + ) + except Exception as e: + logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}") + return None - model_string, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=global_model["model_id"], + if config_id < 0: + global_cfg = get_global_vision_llm_config(config_id) + if not global_cfg: + logger.error(f"Global vision LLM config {config_id} not found") + return None + + if global_cfg.get("custom_provider"): + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" + else: + provider_prefix = VISION_PROVIDER_MAP.get( + global_cfg["provider"].upper(), + global_cfg["provider"].lower(), + ) + model_string = f"{provider_prefix}/{global_cfg['model_name']}" + + litellm_kwargs = { + "model": model_string, + "api_key": global_cfg["api_key"], + } + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), ) + if api_base: + litellm_kwargs["api_base"] = api_base + if global_cfg.get("litellm_params"): + litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -499,7 +581,7 @@ async def get_vision_llm( inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) - billing_tier = str(global_model.get("billing_tier", "free")).lower() + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() if billing_tier == "premium": return QuotaCheckedVisionLLM( inner_llm, @@ -507,23 +589,47 @@ async def get_vision_llm( search_space_id=search_space_id, billing_tier=billing_tier, base_model=model_string, - quota_reserve_tokens=global_model.get("catalog", {}).get( - "quota_reserve_tokens" - ), + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), ) return inner_llm - model = await _get_db_model(session, config_id, search_space) - if not model or not _has_capability(model, "vision"): + # User-owned (positive ID) BYOK configs — always free. + result = await session.execute( + select(VisionLLMConfig).where( + VisionLLMConfig.id == config_id, + VisionLLMConfig.search_space_id == search_space_id, + ) + ) + vision_cfg = result.scalars().first() + if not vision_cfg: logger.error( - f"Vision model {config_id} not found in search space {search_space_id}" + f"Vision LLM config {config_id} not found in search space {search_space_id}" ) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=model.connection, - model_id=model.model_id, + if vision_cfg.custom_provider: + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" + else: + provider_prefix = VISION_PROVIDER_MAP.get( + vision_cfg.provider.value.upper(), + vision_cfg.provider.value.lower(), + ) + model_string = f"{provider_prefix}/{vision_cfg.model_name}" + + litellm_kwargs = { + "model": model_string, + "api_key": vision_cfg.api_key, + } + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, ) + if api_base: + litellm_kwargs["api_base"] = api_base + if vision_cfg.litellm_params: + litellm_kwargs.update(vision_cfg.litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, diff --git a/surfsense_backend/app/services/model_capabilities.py b/surfsense_backend/app/services/model_capabilities.py deleted file mode 100644 index fb7681f35..000000000 --- a/surfsense_backend/app/services/model_capabilities.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Override-aware model capability lookup.""" - -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any - -CAPABILITY_FIELDS = { - "chat": "supports_chat", - "vision": "supports_image_input", - "image_gen": "supports_image_generation", - "tools": "supports_tools", -} - - -def _get_value(model: Any, key: str) -> Any: - if isinstance(model, Mapping): - return model.get(key) - return getattr(model, key, None) - - -def has_capability(model: Any, capability: str) -> bool: - field = CAPABILITY_FIELDS.get(capability) - if field is None: - return False - - override = _get_value(model, "capabilities_override") or {} - if isinstance(override, Mapping) and field in override: - return bool(override[field]) - if isinstance(override, Mapping) and capability in override: - return bool(override[capability]) - - return bool(_get_value(model, field)) - - -__all__ = ["CAPABILITY_FIELDS", "has_capability"] diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py deleted file mode 100644 index cdfd1d725..000000000 --- a/surfsense_backend/app/services/model_connection_service.py +++ /dev/null @@ -1,490 +0,0 @@ -"""Connection verification, model discovery, and capability probing.""" - -from __future__ import annotations - -import contextlib -import logging -from dataclasses import dataclass -from typing import Any - -import anyio -import httpx -import litellm - -from app.db import Connection, Model, ModelSource -from app.services.model_resolver import ensure_v1, to_litellm -from app.services.openrouter_model_normalizer import normalize_openrouter_models -from app.services.provider_registry import Transport, provider_label, spec_for - -logger = logging.getLogger(__name__) - -VERIFY_TIMEOUT_SECONDS = 8.0 -DISCOVERY_TIMEOUT_SECONDS = 15.0 -TEST_TIMEOUT_SECONDS = 30.0 - - -@dataclass(frozen=True) -class VerifyResult: - status: str - ok: bool - message: str = "" - - -class ModelDiscoveryError(Exception): - """User-correctable discovery failure for provider configuration issues.""" - - -def _auth_headers(conn: Connection) -> dict[str, str]: - if not conn.api_key: - return {} - return {"Authorization": f"Bearer {conn.api_key}"} - - -def _anthropic_headers(conn: Connection) -> dict[str, str]: - headers = {"anthropic-version": "2023-06-01"} - if conn.api_key: - headers["x-api-key"] = conn.api_key - return headers - - -def _base_url_or_default(conn: Connection) -> str | None: - if conn.base_url: - return conn.base_url.rstrip("/") - if conn.provider == "openai": - return "https://api.openai.com/v1" - if conn.provider == "anthropic": - return "https://api.anthropic.com/v1" - return spec_for(conn.provider).default_base_url - - -def _docker_hint(url: str | None, exc_or_status: Any) -> str: - raw = str(exc_or_status) - if not url: - return raw - if "localhost" in url or "127.0.0.1" in url: - return ( - f"{raw}. The backend is running inside Docker; localhost means the " - "backend container. Use host.docker.internal and make sure the model " - "server listens on 0.0.0.0." - ) - if "host.docker.internal" in url and ( - "refused" in raw.lower() or "connect" in raw.lower() - ): - return ( - f"{raw}. The host is reachable only if your local model server is " - "listening on 0.0.0.0. On Linux Docker, add " - "`host.docker.internal:host-gateway` to extra_hosts." - ) - return raw - - -def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult: - provider_name = provider_label(conn.provider) - raw = str(exc) - normalized = raw.lower() - exc_name = exc.__class__.__name__.lower() - status_code = getattr(exc, "status_code", None) - - logger.info( - "Model test failed for provider=%s model=%s: %s", - conn.provider, - model_id, - raw, - ) - - if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized: - return VerifyResult( - "AUTH_FAILED", - False, - f"Authentication failed. Check your {provider_name} credentials and try again.", - ) - - if status_code == 404 or "notfound" in exc_name or "not found" in normalized: - if conn.provider == "azure": - message = ( - "Azure OpenAI deployment was not found. Check the deployment name, " - "API version, and endpoint." - ) - else: - message = f"Model '{model_id}' was not found on {provider_name}." - return VerifyResult("NOT_FOUND", False, message) - - if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized: - return VerifyResult( - "RATE_LIMITED", - False, - f"{provider_name} rate limited the model test. Try again later.", - ) - - if "timeout" in exc_name or "timed out" in normalized: - return VerifyResult( - "TIMEOUT", - False, - f"{provider_name} did not respond in time. Check the endpoint and try again.", - ) - - if "connection" in exc_name or "connect" in normalized: - return VerifyResult( - "UNREACHABLE", - False, - _docker_hint( - _base_url_or_default(conn), - f"Could not reach {provider_name}. Check the endpoint and try again.", - ), - ) - - return VerifyResult( - "UNREACHABLE", - False, - f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.", - ) - - -async def verify_connection(conn: Connection) -> VerifyResult: - spec = spec_for(conn.provider) - base_url = _base_url_or_default(conn) - if spec.base_url_required and not base_url: - return VerifyResult("UNREACHABLE", False, "Base URL is required.") - - if spec.transport == Transport.OLLAMA and base_url: - url = f"{base_url.rstrip('/')}/api/version" - elif spec.discovery in {"openai_models", "openrouter"} and base_url: - url = f"{ensure_v1(base_url)}/models" - elif spec.discovery == "anthropic_models" and base_url: - url = f"{base_url.rstrip('/')}/models" - else: - return VerifyResult( - "OK", True, "Connection uses provider-native authentication." - ) - - try: - async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - headers = ( - _anthropic_headers(conn) - if spec.auth_style == "x-api-key" - else _auth_headers(conn) - ) - response = await client.get(url, headers=headers) - if response.status_code in (401, 403): - return VerifyResult("AUTH_FAILED", False, "Authentication failed.") - if response.status_code == 404: - if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"): - message = "Ollama native API should not use /v1." - elif spec.transport == Transport.OPENAI_COMPATIBLE: - message = "OpenAI-compatible servers should expose /v1/models." - else: - message = "Endpoint returned 404." - return VerifyResult("NOT_FOUND", False, message) - response.raise_for_status() - return VerifyResult("OK", True, "Connection verified.") - except httpx.ConnectError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) - except httpx.TimeoutException as exc: - return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") - except httpx.HTTPError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) - - -def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str: - base_url = _base_url_or_default(conn) - if isinstance(exc, httpx.HTTPStatusError): - status_code = exc.response.status_code - if status_code in (401, 403): - return "Authentication failed while discovering models." - if status_code == 404: - spec = spec_for(conn.provider) - if spec.transport == Transport.OPENAI_COMPATIBLE: - return "OpenAI-compatible servers should expose /v1/models." - return "Model discovery endpoint returned 404." - return f"Model discovery failed with HTTP {status_code}." - if isinstance(exc, httpx.TimeoutException): - return f"Model discovery timed out: {exc}" - return _docker_hint(base_url, exc) - - -def _allowlist(conn: Connection) -> set[str]: - raw = (conn.extra or {}).get("model_ids") or [] - return {str(item).strip() for item in raw if str(item).strip()} - - -def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]: - with contextlib.suppress(Exception): - info = litellm.get_model_info(model=model_string) - if isinstance(info, dict): - return info - return ( - litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} - ) - - -def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]: - info = _litellm_info(model_string, model_id) - mode = info.get("mode") - supports_image_input = False - supports_tools = False - with contextlib.suppress(Exception): - supports_image_input = bool(litellm.supports_vision(model=model_string)) - with contextlib.suppress(Exception): - supports_tools = bool(litellm.supports_function_calling(model=model_string)) - return { - "supports_chat": mode in (None, "chat", "completion", "responses"), - "max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"), - "supports_image_input": supports_image_input, - "supports_tools": supports_tools, - "supports_image_generation": mode - in {"image_generation", "image_generation_model"}, - } - - -def derive_capabilities( - conn: Connection, model_id: str, metadata: dict | None = None -) -> dict[str, Any]: - metadata = metadata or {} - spec = spec_for(conn.provider) - model_string, _ = to_litellm(conn, model_id) - facts = _classify_from_litellm(model_string, model_id) - if spec.transport == Transport.OLLAMA: - caps = set(metadata.get("capabilities") or []) - details = metadata.get("details") or {} - facts.update( - { - "supports_chat": "embedding" not in caps, - "supports_image_input": "vision" in caps - or facts["supports_image_input"], - "supports_tools": "tools" in caps or facts["supports_tools"], - "supports_image_generation": False, - "max_input_tokens": metadata.get("context_length") - or metadata.get("num_ctx") - or details.get("context_length") - or facts["max_input_tokens"], - } - ) - return facts - - -async def _discover_openai_shaped_models( - conn: Connection, base_url: str | None -) -> list[dict[str, Any]]: - resolved_base_url = base_url or _base_url_or_default(conn) - if not resolved_base_url: - return [] - - url = f"{ensure_v1(resolved_base_url)}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - - results: list[dict[str, Any]] = [] - for item in response.json().get("data", []): - model_id = item.get("id") - if not model_id: - continue - results.append( - { - "model_id": model_id, - "display_name": item.get("name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, item), - "metadata": item, - } - ) - return results - - -async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: - base_url = _base_url_or_default(conn) - if not base_url: - return [] - - url = f"{base_url.rstrip('/')}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_anthropic_headers(conn)) - response.raise_for_status() - - results: list[dict[str, Any]] = [] - for item in response.json().get("data", []): - model_id = item.get("id") - if not model_id: - continue - results.append( - { - "model_id": model_id, - "display_name": item.get("display_name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, item), - "metadata": item, - } - ) - return results - - -async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]: - if not conn.base_url: - return [] - - base_url = conn.base_url.rstrip("/") - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("models", []) - results: list[dict[str, Any]] = [] - for item in models: - model_id = item.get("model") or item.get("name") - if not model_id: - continue - metadata = dict(item) - with contextlib.suppress(Exception): - show_response = await client.post( - f"{base_url}/api/show", - json={"model": model_id}, - headers=_auth_headers(conn), - ) - show_response.raise_for_status() - metadata.update(show_response.json()) - results.append( - { - "model_id": model_id, - "display_name": item.get("name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, metadata), - "metadata": metadata, - } - ) - return results - - -async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]: - base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get( - f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn) - ) - response.raise_for_status() - return normalize_openrouter_models(response.json().get("data", [])) - - -def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: - provider = conn.provider - prefix = spec_for(provider).litellm_prefix or provider - results: list[dict[str, Any]] = [] - for model_string, metadata in litellm.model_cost.items(): - if not isinstance(model_string, str) or not model_string.startswith( - f"{prefix}/" - ): - continue - model_id = model_string.split("/", 1)[1] - results.append( - { - "model_id": model_id, - "display_name": metadata.get("display_name") or model_id, - "source": ModelSource.DISCOVERED, - **_classify_from_litellm(model_string, model_id), - "metadata": metadata, - } - ) - return results - - -async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: - params = (conn.extra or {}).get("litellm_params", {}) - region_name = params.get("aws_region_name") - if not region_name: - return [] - - def list_models() -> list[dict[str, Any]]: - import os - - import boto3 - - if bearer_token := params.get("aws_bearer_token_bedrock"): - try: - os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token - client = boto3.client("bedrock", region_name=region_name) - finally: - os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None) - else: - client_kwargs: dict[str, str] = {"region_name": region_name} - if params.get("aws_access_key_id"): - client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] - if params.get("aws_secret_access_key"): - client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] - client = boto3.client("bedrock", **client_kwargs) - - response = client.list_foundation_models() - results: list[dict[str, Any]] = [] - for item in response.get("modelSummaries", []): - model_id = item.get("modelId") - if not model_id: - continue - input_modalities = set(item.get("inputModalities") or []) - output_modalities = set(item.get("outputModalities") or []) - results.append( - { - "model_id": model_id, - "display_name": item.get("modelName") or model_id, - "source": ModelSource.DISCOVERED, - "supports_chat": "TEXT" in input_modalities - and "TEXT" in output_modalities, - "supports_image_input": "IMAGE" in input_modalities, - "supports_tools": None, - "supports_image_generation": "IMAGE" in output_modalities, - "max_input_tokens": None, - "metadata": item, - } - ) - return results - - return await anyio.to_thread.run_sync(list_models) - - -async def discover_models(conn: Connection) -> list[dict[str, Any]]: - allowlist = _allowlist(conn) - spec = spec_for(conn.provider) - - try: - if spec.discovery == "ollama": - results = await _ollama_tags_then_show(conn) - elif spec.discovery == "openrouter": - results = await _openrouter_models(conn) - elif spec.discovery == "anthropic_models": - results = await _discover_anthropic_models(conn) - elif spec.discovery == "openai_models": - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif spec.discovery == "bedrock_models": - results = await _discover_bedrock_models(conn) - elif spec.discovery == "static": - results = _litellm_static_models(conn) - else: - results = [] - except httpx.HTTPError as exc: - raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc - - if allowlist: - results = [item for item in results if item["model_id"] in allowlist] - return results - - -async def test_model(conn: Connection, model: Model) -> VerifyResult: - model_string, kwargs = to_litellm(conn, model.model_id) - try: - await litellm.acompletion( - model=model_string, - messages=[{"role": "user", "content": "Hello"}], - timeout=TEST_TIMEOUT_SECONDS, - **kwargs, - ) - except Exception as exc: - return _model_test_error(conn, model.model_id, exc) - - model.supports_chat = True - return VerifyResult("OK", True, "Model test succeeded.") - - -__all__ = [ - "ModelDiscoveryError", - "VerifyResult", - "derive_capabilities", - "discover_models", - "test_model", - "verify_connection", -] diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index ffb430756..33837a8a0 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -12,8 +12,6 @@ from pathlib import Path import httpx -from app.services.openrouter_model_normalizer import normalize_openrouter_models - logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" @@ -24,7 +22,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours _cache: list[dict] | None = None _cache_timestamp: float = 0 -# Maps OpenRouter provider slug to native LiteLLM provider prefixes. +# Maps OpenRouter provider slug → our LiteLLMProvider enum value. # Only providers where the model-name part (after the slash) can be # used directly with the native provider's litellm prefix are listed. # @@ -123,13 +121,26 @@ def _process_models(raw_models: list[dict]) -> list[dict]: """ processed: list[dict] = [] - for normalized in normalize_openrouter_models(raw_models): - model_id: str = normalized["model_id"] - name: str = normalized.get("display_name") or model_id - context_length = normalized.get("max_input_tokens") + for model in raw_models: + model_id: str = model.get("id", "") + name: str = model.get("name", "") + context_length = model.get("context_length") + if "/" not in model_id: continue + if not _is_text_output_model(model): + continue + + if not _supports_tool_calling(model): + continue + + if not _has_sufficient_context(model): + continue + + if not _is_allowed_model(model): + continue + provider_slug, model_name = model_id.split("/", 1) context_window = _format_context_length(context_length) @@ -143,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]: } ) - # 2) Emit for the direct provider when we have a mapping - direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) - if direct_provider: + # 2) Emit for the native provider when we have a mapping + native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) + if native_provider: # Google's Gemini API only serves gemini-* models. # Open-source models like gemma-* are NOT available through it. - if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): + if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": direct_provider, + "provider": native_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py deleted file mode 100644 index 628c9f473..000000000 --- a/surfsense_backend/app/services/model_resolver.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Single model-to-LiteLLM resolver. - -All chat, vision, image-generation, validation, and Auto routing paths should -turn a Connection + Model into LiteLLM input through this module. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from app.db import Connection - -from app.services.provider_registry import Transport, spec_for - - -def ensure_v1(base_url: str | None) -> str | None: - if not base_url: - return None - stripped = base_url.rstrip("/") - if stripped.endswith("/v1"): - return stripped - return f"{stripped}/v1" - - -def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: - if isinstance(conn, Mapping): - return conn.get(key) - return getattr(conn, key) - - -def to_litellm( - conn: Connection | Mapping[str, Any], - model_id: str, -) -> tuple[str, dict[str, Any]]: - """Return ``(model_string, litellm_kwargs)`` for any model role.""" - provider = _conn_value(conn, "provider") - base_url = _conn_value(conn, "base_url") - api_key = _conn_value(conn, "api_key") - extra = _conn_value(conn, "extra") or {} - spec = spec_for(provider) - - kwargs: dict[str, Any] = {} - if api_key: - kwargs["api_key"] = api_key - - prefix = spec.litellm_prefix or str(provider) - model_string = f"{prefix}/{model_id}" if prefix else model_id - if base_url: - api_base = ( - ensure_v1(base_url) - if spec.transport == Transport.OPENAI_COMPATIBLE - else base_url.rstrip("/") - ) - kwargs["api_base"] = api_base - - if api_version := extra.get("api_version"): - kwargs["api_version"] = api_version - kwargs.update(extra.get("litellm_params", {})) - kwargs.update(extra.get("kwargs", {})) - if provider == "bedrock" and ( - bearer_token := kwargs.pop("aws_bearer_token_bedrock", None) - ): - kwargs["api_key"] = bearer_token - return model_string, kwargs - - -def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: - """Build an in-memory connection mapping from a global config.""" - provider = str( - config.get("provider") - or config.get("litellm_provider") - or config.get("custom_provider") - or "openai" - ) - extra: dict[str, Any] = { - "litellm_params": config.get("litellm_params") or {}, - } - if config.get("api_version"): - extra["api_version"] = config.get("api_version") - return { - "provider": provider, - "base_url": config.get("api_base") or None, - "api_key": config.get("api_key") or None, - "extra": extra, - } - - -__all__ = [ - "ensure_v1", - "native_connection_from_config", - "to_litellm", -] diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index cd05d7935..13f43d1ee 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -199,12 +199,11 @@ async def _extract_binary_attachment_markdown( async def _run_etl_extract(*, file_path: str, filename: str, vision_llm): """Lazy-load ETL dependencies to avoid module-import cycles.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - return await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + return await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename) ) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 17b8c10eb..6454e2d58 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -19,10 +19,6 @@ from typing import Any import httpx -from app.services.openrouter_model_normalizer import ( - is_openrouter_image_model, - normalize_openrouter_models, -) from app.services.quality_score import ( _HEALTH_BLEND_WEIGHT, _HEALTH_ENRICH_CONCURRENCY, @@ -278,7 +274,7 @@ def _generate_configs( OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer - because our own Auto pin + 24 h refresh + repair logic already + because our own Auto (Fastest) pin + 24 h refresh + repair logic already cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) @@ -296,16 +292,24 @@ def _generate_configs( use_default: bool = settings.get("use_default_system_instructions", True) citations_enabled: bool = settings.get("citations_enabled", True) - text_models = normalize_openrouter_models(raw_models) + text_models = [ + m + for m in raw_models + if _is_text_output_model(m) + and _supports_tool_calling(m) + and _has_sufficient_context(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] configs: list[dict] = [] taken: set[int] = set() now_ts = int(time.time()) - for normalized in text_models: - model = normalized.get("metadata") or {} - model_id: str = normalized["model_id"] - name: str = normalized.get("display_name") or model_id + for model in text_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) tier = _openrouter_tier(model) static_q = static_score_or(model, now_ts=now_ts) @@ -319,10 +323,10 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), @@ -341,9 +345,9 @@ def _generate_configs( # ``stream_new_chat`` as a fail-fast safety net before the # OpenRouter request would otherwise 404 with # ``"No endpoints found that support image input"``. - "supports_image_input": bool(normalized.get("supports_image_input")), + "supports_image_input": _supports_image_input(model), _OPENROUTER_DYNAMIC_MARKER: True, - # Auto ranking metadata. ``quality_score`` is initialised + # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next # ``_enrich_health`` pass (synchronous on refresh, deferred on cold # start so startup latency is unchanged). @@ -358,7 +362,11 @@ def _generate_configs( return configs +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 def _generate_image_gen_configs( @@ -392,7 +400,14 @@ def _generate_image_gen_configs( free_rpm: int = settings.get("free_rpm", 20) litellm_params: dict = settings.get("litellm_params") or {} - image_models = [m for m in raw_models if is_openrouter_image_model(m)] + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] configs: list[dict] = [] taken: set[int] = set() @@ -405,9 +420,14 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -420,6 +440,93 @@ def _generate_image_gen_configs( return configs +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). + try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -446,9 +553,11 @@ class OpenRouterIntegrationService: # Cached raw catalogue from the most recent fetch. Image / vision # emitters reuse this to avoid a second network call per surface. self._raw_models: list[dict] = [] - # Image config cache (only populated when the matching opt-in flag is - # true on initialize). Refreshed in lockstep with the chat catalogue. + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -483,7 +592,7 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._raw_pricing = _extract_raw_pricing(raw_models) - # Populate image cache when its opt-in flag is set. + # Populate image / vision caches when their opt-in flag is set. # Empty otherwise so the accessors return [] without re-running # filters every refresh. if settings.get("image_generation_enabled"): @@ -495,6 +604,15 @@ class OpenRouterIntegrationService: else: self._image_configs = [] + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -548,9 +666,9 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - # Image list is atomic-swapped the same way: filter out + # Image / vision lists are atomic-swapped the same way: filter out # the previous dynamic entries from the live config list and append - # the freshly generated ones. No-op when the opt-in flag is off. + # the freshly generated ones. No-ops when the opt-in flag is off. if self._settings.get("image_generation_enabled"): new_image = _generate_image_gen_configs(raw_models, self._settings) static_image = [ @@ -561,6 +679,16 @@ class OpenRouterIntegrationService: app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image self._image_configs = new_image + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -582,7 +710,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with provider=openrouter + # re-stamps health for any YAML-curated cfg with provider==OPENROUTER # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -630,7 +758,7 @@ class OpenRouterIntegrationService: return counts # ------------------------------------------------------------------ - # Auto health enrichment + # Auto (Fastest) health enrichment # ------------------------------------------------------------------ async def _enrich_health_safely( @@ -659,7 +787,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("provider", "")).lower() == "openrouter" + c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" ] if not or_cfgs: return @@ -840,6 +968,17 @@ class OpenRouterIntegrationService: """ return list(self._image_configs) + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + def get_raw_pricing(self) -> dict[str, dict[str, str]]: """Return the cached raw OpenRouter pricing map. diff --git a/surfsense_backend/app/services/openrouter_model_normalizer.py b/surfsense_backend/app/services/openrouter_model_normalizer.py deleted file mode 100644 index 5998a2f1f..000000000 --- a/surfsense_backend/app/services/openrouter_model_normalizer.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Shared OpenRouter model normalization. - -OpenRouter metadata is richer than generic OpenAI-compatible ``/models`` -responses. Keep all OpenRouter filtering and capability extraction here so -GLOBAL catalogue generation and BYOK discovery agree. -""" - -from __future__ import annotations - -from typing import Any - -from app.db import ModelSource - -MIN_CONTEXT_LENGTH = 100_000 - -EXCLUDED_PROVIDER_SLUGS = {"amazon"} -EXCLUDED_MODEL_IDS: set[str] = { - "openai/gpt-4-1106-preview", - "openai/gpt-4-turbo-preview", - "openai/gpt-4o:extended", - "arcee-ai/virtuoso-large", - "openai/o3-deep-research", - "openai/o4-mini-deep-research", - "openrouter/free", -} -EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) - - -def is_text_output_model(model: dict[str, Any]) -> bool: - output_mods = model.get("architecture", {}).get("output_modalities", []) - return output_mods == ["text"] - - -def is_image_output_model(model: dict[str, Any]) -> bool: - output_mods = model.get("architecture", {}).get("output_modalities", []) or [] - return "image" in output_mods - - -def supports_image_input(model: dict[str, Any]) -> bool: - input_mods = model.get("architecture", {}).get("input_modalities", []) or [] - return "image" in input_mods - - -def supports_tool_calling(model: dict[str, Any]) -> bool: - supported = model.get("supported_parameters") or [] - return "tools" in supported - - -def has_sufficient_context(model: dict[str, Any]) -> bool: - return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH - - -def is_compatible_provider(model: dict[str, Any]) -> bool: - model_id = str(model.get("id") or "") - slug = model_id.split("/", 1)[0] if "/" in model_id else "" - return slug not in EXCLUDED_PROVIDER_SLUGS - - -def is_allowed_model(model: dict[str, Any]) -> bool: - model_id = str(model.get("id") or "") - if model_id in EXCLUDED_MODEL_IDS: - return False - base_id = model_id.split(":")[0] - return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES) - - -def is_openrouter_chat_model(model: dict[str, Any]) -> bool: - return ( - "/" in str(model.get("id") or "") - and is_text_output_model(model) - and supports_tool_calling(model) - and has_sufficient_context(model) - and is_compatible_provider(model) - and is_allowed_model(model) - ) - - -def is_openrouter_image_model(model: dict[str, Any]) -> bool: - return ( - "/" in str(model.get("id") or "") - and is_image_output_model(model) - and is_compatible_provider(model) - and is_allowed_model(model) - ) - - -def normalize_openrouter_models( - raw_models: list[dict[str, Any]], -) -> list[dict[str, Any]]: - normalized: list[dict[str, Any]] = [] - for model in raw_models: - if not is_openrouter_chat_model(model): - continue - model_id = str(model.get("id") or "") - normalized.append( - { - "model_id": model_id, - "display_name": model.get("name") or model_id, - "source": ModelSource.DISCOVERED, - "supports_chat": True, - "max_input_tokens": model.get("context_length"), - "supports_image_input": supports_image_input(model), - "supports_tools": supports_tool_calling(model), - "supports_image_generation": False, - "metadata": model, - } - ) - return normalized - - -__all__ = [ - "MIN_CONTEXT_LENGTH", - "has_sufficient_context", - "is_allowed_model", - "is_compatible_provider", - "is_image_output_model", - "is_openrouter_chat_model", - "is_openrouter_image_model", - "is_text_output_model", - "normalize_openrouter_models", - "supports_image_input", - "supports_tool_calling", -] diff --git a/surfsense_backend/app/services/etl_credit_service.py b/surfsense_backend/app/services/page_limit_service.py similarity index 67% rename from surfsense_backend/app/services/etl_credit_service.py rename to surfsense_backend/app/services/page_limit_service.py index 5c4ea4bbd..47fe07fc6 100644 --- a/surfsense_backend/app/services/etl_credit_service.py +++ b/surfsense_backend/app/services/page_limit_service.py @@ -1,14 +1,5 @@ """ -Service for charging the unified credit wallet for ETL document processing. - -Replaces the legacy ``PageLimitService`` page-quota model. Page counts are -still estimated the same way; they are now converted to USD micro-credits -(``config.MICROS_PER_PAGE`` per page, times a per-mode multiplier) and debited -from ``user.credit_micros_balance``. - -When ``config.ETL_CREDIT_BILLING_ENABLED`` is False (the default for -self-hosted / OSS installs) every check/charge is a no-op, preserving the prior -effectively-unlimited ETL behaviour. +Service for managing user page limits for ETL services. """ import os @@ -17,125 +8,141 @@ from pathlib import Path, PurePosixPath from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import config - -class InsufficientCreditsError(Exception): - """Raised when a user lacks enough credit to process a document.""" +class PageLimitExceededError(Exception): + """ + Exception raised when a user exceeds their page processing limit. + """ def __init__( self, - message: str = "Insufficient credits to process this document. " - "Add more credits to continue.", - balance_micros: int = 0, - required_micros: int = 0, + message: str = "Page limit exceeded. Please contact admin to increase limits for your account.", + pages_used: int = 0, + pages_limit: int = 0, + pages_to_add: int = 0, ): - self.balance_micros = balance_micros - self.required_micros = required_micros + self.pages_used = pages_used + self.pages_limit = pages_limit + self.pages_to_add = pages_to_add super().__init__(message) -class EtlCreditService: - """Checks and charges the credit wallet for ETL page processing.""" +class PageLimitService: + """Service for checking and updating user page limits.""" def __init__(self, session: AsyncSession): self.session = session - @staticmethod - def billing_enabled() -> bool: - return config.ETL_CREDIT_BILLING_ENABLED - - @staticmethod - def pages_to_micros(pages: int, multiplier: int = 1) -> int: - """Convert a (multiplied) page count to USD micro-credits.""" - return int(pages) * int(multiplier) * config.MICROS_PER_PAGE - - async def get_available_micros(self, user_id: str) -> int | None: - """Return spendable credit in micro-USD (``balance - reserved``). - - Returns ``None`` when ETL billing is disabled, which callers treat as - "unlimited" (no batch skipping, no blocking). + async def check_page_limit( + self, user_id: str, estimated_pages: int = 1 + ) -> tuple[bool, int, int]: """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return None + Check if user has enough pages remaining for processing. + Args: + user_id: The user's ID + estimated_pages: Estimated number of pages to be processed + + Returns: + Tuple of (has_capacity, pages_used, pages_limit) + + Raises: + PageLimitExceededError: If user would exceed their page limit + """ from app.db import User + # Get user's current page usage result = await self.session.execute( - select(User.credit_micros_balance, User.credit_micros_reserved).where( - User.id == user_id - ) + select(User.pages_used, User.pages_limit).where(User.id == user_id) ) row = result.first() + if not row: raise ValueError(f"User with ID {user_id} not found") - balance, reserved = row - return balance - reserved + pages_used, pages_limit = row - async def check_credits( - self, user_id: str, estimated_pages: int = 1, multiplier: int = 1 - ) -> None: - """Raise :class:`InsufficientCreditsError` if the user can't afford to - process ``estimated_pages`` (times ``multiplier``). - - No-op when ETL billing is disabled. - """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return - - required = self.pages_to_micros(estimated_pages, multiplier) - available = await self.get_available_micros(user_id) - if available is None: - return - - if required > available: - raise InsufficientCreditsError( - message=( - "Processing this document would exceed your available " - f"credit. Available: ${available / 1_000_000:.2f}. " - f"This document costs about ${required / 1_000_000:.2f} " - f"({estimated_pages} page(s)). Add more credits to continue." - ), - balance_micros=available, - required_micros=required, + # Check if adding estimated pages would exceed limit + if pages_used + estimated_pages > pages_limit: + raise PageLimitExceededError( + message=f"Processing this document would exceed your page limit. " + f"Used: {pages_used}/{pages_limit} pages. " + f"Document has approximately {estimated_pages} page(s). " + f"Please contact admin to increase limits for your account.", + pages_used=pages_used, + pages_limit=pages_limit, + pages_to_add=estimated_pages, ) - async def charge_credits( - self, user_id: str, pages: int, multiplier: int = 1 - ) -> int | None: - """Debit the credit wallet after successful processing. + return True, pages_used, pages_limit - The balance may dip slightly negative when the actual page count - exceeds the pre-check estimate (the document is already processed), - mirroring the prior ``allow_exceed=True`` semantics. - - Returns the new balance in micros, or ``None`` when billing is disabled. + async def update_page_usage( + self, user_id: str, pages_to_add: int, allow_exceed: bool = False + ) -> int: """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return None + Update user's page usage after successful processing. + Args: + user_id: The user's ID + pages_to_add: Number of pages to add to usage + allow_exceed: If True, allows update even if it exceeds limit + (used when document was already processed after passing initial check) + + Returns: + New total pages_used value + + Raises: + PageLimitExceededError: If adding pages would exceed limit and allow_exceed is False + """ from app.db import User + # Get user result = await self.session.execute(select(User).where(User.id == user_id)) user = result.unique().scalar_one_or_none() + if not user: raise ValueError(f"User with ID {user_id} not found") - cost = self.pages_to_micros(pages, multiplier) - user.credit_micros_balance -= cost + # Check if this would exceed limit (only if allow_exceed is False) + new_usage = user.pages_used + pages_to_add + if not allow_exceed and new_usage > user.pages_limit: + raise PageLimitExceededError( + message=f"Cannot update page usage. Would exceed limit. " + f"Current: {user.pages_used}/{user.pages_limit}, " + f"Trying to add: {pages_to_add}", + pages_used=user.pages_used, + pages_limit=user.pages_limit, + pages_to_add=pages_to_add, + ) + + # Update usage + user.pages_used = new_usage await self.session.commit() await self.session.refresh(user) - # Best-effort: fire an auto-reload check if the balance dropped low. - try: - from app.services.auto_reload_service import maybe_trigger_auto_reload + return user.pages_used - await maybe_trigger_auto_reload(user_id) - except Exception: - pass + async def get_page_usage(self, user_id: str) -> tuple[int, int]: + """ + Get user's current page usage and limit. - return user.credit_micros_balance + Args: + user_id: The user's ID + + Returns: + Tuple of (pages_used, pages_limit) + """ + from app.db import User + + result = await self.session.execute( + select(User.pages_used, User.pages_limit).where(User.id == user_id) + ) + row = result.first() + + if not row: + raise ValueError(f"User with ID {user_id} not found") + + return row def estimate_pages_from_elements(self, elements: list) -> int: """ diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 7343df737..de98e50c2 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,19 +143,21 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider") or "").upper() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() - if provider == "openrouter": + if provider == "OPENROUTER": entry = or_pricing.get(model_name) if entry: input_cost = _safe_float(entry.get("prompt")) output_cost = _safe_float(entry.get("completion")) else: - # Some dynamically materialized configs can carry pricing - # inline when the raw OpenRouter cache has no matching entry. + # Vision configs from ``_generate_vision_llm_configs`` + # carry their pricing inline because the OpenRouter + # raw-pricing cache is keyed by chat-catalogue model_id; + # vision flows pick up the inline values here. input_cost = _safe_float(cfg.get("input_cost_per_token")) output_cost = _safe_float(cfg.get("output_cost_per_token")) if input_cost == 0.0 and output_cost == 0.0: @@ -187,11 +189,12 @@ def _register_chat_shape_configs( skipped_no_pricing += 1 continue aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() count = _register( aliases, input_cost=input_cost, output_cost=output_cost, - provider=provider, + provider=provider_slug, ) if count > 0: registered_models += 1 @@ -214,8 +217,9 @@ def _register_chat_shape_configs( def register_pricing_from_global_configs() -> None: """Register pricing for every known LLM deployment with LiteLLM. - Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve - cost from the same chat-shaped deployment configs: + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: 1. ``OPENROUTER``: pulls the cached raw pricing from ``OpenRouterIntegrationService`` (populated during its own @@ -242,7 +246,10 @@ def register_pricing_from_global_configs() -> None: from app.config import config as app_config chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) - if not chat_configs: + vision_configs: list[dict] = list( + getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] + ) + if not chat_configs and not vision_configs: logger.info("[PricingRegistration] no global configs to register") return @@ -261,3 +268,7 @@ def register_pricing_from_global_configs() -> None: if chat_configs: _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") + if vision_configs: + _register_chat_shape_configs( + vision_configs, or_pricing=or_pricing, label="vision" + ) diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py new file mode 100644 index 000000000..dca1f9462 --- /dev/null +++ b/surfsense_backend/app/services/provider_api_base.py @@ -0,0 +1,106 @@ +"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. + +LiteLLM falls back to the module-global ``litellm.api_base`` when an +individual call doesn't pass one, which silently inherits provider-agnostic +env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an +explicit ``api_base``, an ``openrouter/<model>`` request can end up at an +Azure endpoint and 404 with ``Resource not found`` (real reproducer: +[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends +``/chat/completions`` to whatever inherited base it gets, regardless of +provider). + +The chat router has had this defense for a while +(``llm_router_service.py:466-478``). This module hoists the maps + cascade +into a tiny standalone helper so vision and image-gen can share the same +source of truth without an inter-service circular import. +""" + +from __future__ import annotations + +PROVIDER_DEFAULT_API_BASE: dict[str, str] = { + "openrouter": "https://openrouter.ai/api/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "perplexity": "https://api.perplexity.ai", + "xai": "https://api.x.ai/v1", + "cerebras": "https://api.cerebras.ai/v1", + "deepinfra": "https://api.deepinfra.com/v1/openai", + "fireworks_ai": "https://api.fireworks.ai/inference/v1", + "together_ai": "https://api.together.xyz/v1", + "anyscale": "https://api.endpoints.anyscale.com/v1", + "cometapi": "https://api.cometapi.com/v1", + "sambanova": "https://api.sambanova.ai/v1", +} +"""Default ``api_base`` per LiteLLM provider prefix (lowercase). + +Only providers with a well-known, stable public base URL are listed — +self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, +huggingface, databricks, cloudflare, replicate) are intentionally omitted +so their existing config-driven behaviour is preserved.""" + + +PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { + "DEEPSEEK": "https://api.deepseek.com/v1", + "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "MOONSHOT": "https://api.moonshot.ai/v1", + "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", + "MINIMAX": "https://api.minimax.io/v1", +} +"""Canonical provider key (uppercase) → base URL. + +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index fae283ab6..f094c9954 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -49,6 +49,51 @@ import litellm logger = logging.getLogger(__name__) +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> deliverables/tools/generate_image -> +# app.config`` cycle that prompted the move. +_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + def _candidate_model_strings( *, provider: str | None, @@ -78,7 +123,12 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix = custom_provider or provider + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. + provider_prefix = custom_provider primary_model = base_model or model_name bare_model = model_name diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py deleted file mode 100644 index 67d1c4db4..000000000 --- a/surfsense_backend/app/services/provider_registry.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Provider registry for model connections. - -The provider string is the single public identity axis. This registry only -describes providers whose behavior differs from LiteLLM's native default. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum -from typing import Literal - - -class Transport(StrEnum): - NATIVE = "NATIVE" - OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - OLLAMA = "OLLAMA" - - -DiscoveryKind = Literal[ - "ollama", - "openai_models", - "anthropic_models", - "bedrock_models", - "openrouter", - "static", - "none", -] - -AuthStyle = Literal["bearer", "x-api-key", "none", "native"] - - -@dataclass(frozen=True) -class ProviderSpec: - transport: Transport - litellm_prefix: str | None - discovery: DiscoveryKind - default_base_url: str | None - base_url_required: bool - auth_style: AuthStyle - display_name: str | None = None - - -REGISTRY: dict[str, ProviderSpec] = { - "openai": ProviderSpec( - Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI" - ), - "anthropic": ProviderSpec( - Transport.NATIVE, - "anthropic", - "anthropic_models", - None, - False, - "x-api-key", - "Anthropic", - ), - "azure": ProviderSpec( - Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI" - ), - "vertex_ai": ProviderSpec( - Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI" - ), - "bedrock": ProviderSpec( - Transport.NATIVE, - "bedrock", - "bedrock_models", - None, - False, - "native", - "Amazon Bedrock", - ), - "openrouter": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openrouter", - "openrouter", - "https://openrouter.ai/api/v1", - False, - "bearer", - "OpenRouter", - ), - "openai_compatible": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openai", - "openai_models", - None, - True, - "bearer", - "OpenAI-compatible provider", - ), - "lm_studio": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openai", - "openai_models", - "http://host.docker.internal:1234/v1", - True, - "bearer", - "LM Studio", - ), - "ollama_chat": ProviderSpec( - Transport.OLLAMA, - "ollama_chat", - "ollama", - "http://host.docker.internal:11434", - True, - "none", - "Ollama", - ), -} - - -def spec_for(provider: str | None) -> ProviderSpec: - provider_key = (provider or "").strip() - return REGISTRY.get(provider_key) or ProviderSpec( - Transport.NATIVE, provider_key or "openai", "static", None, False, "native" - ) - - -def provider_label(provider: str | None) -> str: - provider_key = (provider or "").strip() - spec = spec_for(provider_key) - if spec.display_name: - return spec.display_name - return provider_key.replace("_", " ").title() if provider_key else "Provider" - - -__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"] diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index d17f411b8..e4e0dd33a 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -337,9 +337,6 @@ async def _get_podcast_for_snapshot( "original_id": podcast.id, "title": podcast.title, "transcript": podcast.podcast_transcript, - "storage_backend": podcast.storage_backend, - "storage_key": podcast.storage_key, - # Legacy fallback for rows rendered before the storage migration. "file_path": podcast.file_location, } @@ -720,8 +717,6 @@ async def clone_from_snapshot( new_podcast = Podcast( title=podcast_info.get("title", "Cloned Podcast"), podcast_transcript=podcast_info.get("transcript"), - storage_backend=podcast_info.get("storage_backend"), - storage_key=podcast_info.get("storage_key"), file_location=podcast_info.get("file_path"), status=PodcastStatus.READY, search_space_id=target_search_space_id, diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 737dd7c2f..2fb37de21 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -1,4 +1,4 @@ -"""Pure-function quality scoring for Auto model selection. +"""Pure-function quality scoring for Auto (Fastest) model selection. This module is import-free of any service / request-path dependencies. All numbers are computed once during the OpenRouter refresh tick (or YAML load) @@ -108,23 +108,25 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = { # YAML provider field (the upstream API shape the operator selected). PROVIDER_PRESTIGE_YAML: dict[str, int] = { - "azure": 50, - "openai": 50, - "anthropic": 50, - "gemini": 50, - "vertex_ai": 50, - "xai": 50, - "mistral": 38, - "deepseek": 38, - "cohere": 38, - "groq": 30, - "together_ai": 28, - "fireworks_ai": 28, - "perplexity": 28, - "bedrock": 28, - "openrouter": 25, - "ollama_chat": 12, - "custom": 12, + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, } @@ -273,7 +275,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider", "")).upper() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index 0cb6cd092..6db5e2604 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -238,14 +238,9 @@ async def _restore_in_place_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk( - document_id=doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=True) + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip( + chunk_texts, chunk_embeddings, strict=True ) ] ) @@ -341,15 +336,8 @@ async def _reinsert_document_from_revision( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk( - document_id=new_doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=True) - ) + Chunk(document_id=new_doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) ] ) diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index d32c18722..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -99,18 +99,7 @@ class QuotaStatus(StrEnum): class QuotaResult: - # ``used``/``limit`` are used by the anonymous (Redis) token path. - # ``balance``/``remaining``/``reserved`` are used by the credit (Postgres) - # path, all in USD micro-units. ``remaining`` == spendable (balance - reserved). - __slots__ = ( - "allowed", - "balance", - "limit", - "remaining", - "reserved", - "status", - "used", - ) + __slots__ = ("allowed", "limit", "remaining", "reserved", "status", "used") def __init__( self, @@ -120,7 +109,6 @@ class QuotaResult: limit: int, reserved: int = 0, remaining: int = 0, - balance: int = 0, ): self.allowed = allowed self.status = status @@ -128,7 +116,6 @@ class QuotaResult: self.limit = limit self.reserved = reserved self.remaining = remaining - self.balance = balance def to_dict(self) -> dict[str, Any]: return { @@ -138,7 +125,6 @@ class QuotaResult: "limit": self.limit, "reserved": self.reserved, "remaining": self.remaining, - "balance": self.balance, } @@ -519,33 +505,19 @@ class TokenQuotaService: # ------------------------------------------------------------------ @staticmethod - def _credit_status(balance: int) -> QuotaStatus: - """Map a spendable balance to OK / WARNING / BLOCKED. - - There is no longer a fixed ceiling, so WARNING fires on a low absolute - balance (``config.CREDIT_LOW_BALANCE_WARNING_MICROS``) instead of a - percentage of a limit. - """ - if balance <= 0: - return QuotaStatus.BLOCKED - if balance < config.CREDIT_LOW_BALANCE_WARNING_MICROS: - return QuotaStatus.WARNING - return QuotaStatus.OK - - @staticmethod - async def credit_reserve( + async def premium_reserve( db_session: AsyncSession, user_id: Any, request_id: str, reserve_micros: int, ) -> QuotaResult: - """Reserve ``reserve_micros`` (USD micro-units) from the user's credit - wallet. + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. - ``QuotaResult.balance``/``reserved``/``remaining`` are in micro-USD on - this code path; callers (chat stream, credit-status route, FE display) - convert to dollars by dividing by 1_000_000. ``remaining`` is the - spendable amount (``balance - reserved``). + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. """ from app.db import User @@ -566,41 +538,48 @@ class TokenQuotaService: limit=0, ) - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - # Block when the new hold would exceed the spendable balance. - if reserved + reserve_micros > balance: - remaining = max(0, balance - reserved) + effective = used + reserved + reserve_micros + if effective > limit: + remaining = max(0, limit - used - reserved) await db_session.rollback() return QuotaResult( allowed=False, status=QuotaStatus.BLOCKED, - used=0, - limit=balance, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) - user.credit_micros_reserved = reserved + reserve_micros + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() new_reserved = reserved + reserve_micros - remaining = max(0, balance - new_reserved) + remaining = max(0, limit - used - new_reserved) + warning_threshold = int(limit * 0.8) + + if (used + new_reserved) >= limit: + status = QuotaStatus.BLOCKED + elif (used + new_reserved) >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( allowed=True, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + status=status, + used=used, + limit=limit, reserved=new_reserved, remaining=remaining, - balance=balance, ) @staticmethod - async def credit_finalize( + async def premium_finalize( db_session: AsyncSession, user_id: Any, request_id: str, @@ -608,8 +587,7 @@ class TokenQuotaService: reserved_micros: int, ) -> QuotaResult: """Settle the reservation: release ``reserved_micros`` and debit - ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD) - from the balance. + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). """ from app.db import User @@ -627,42 +605,44 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.credit_micros_reserved = max( - 0, user.credit_micros_reserved - reserved_micros + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.credit_micros_balance = user.credit_micros_balance - actual_micros await db_session.commit() - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved - remaining = max(0, balance - reserved) + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved + remaining = max(0, limit - used - reserved) - # Best-effort auto-reload nudge after the debit settles. - try: - from app.services.auto_reload_service import maybe_trigger_auto_reload - - await maybe_trigger_auto_reload(user_id) - except Exception: - pass + warning_threshold = int(limit * 0.8) + if used >= limit: + status = QuotaStatus.BLOCKED + elif used >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( allowed=True, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + status=status, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) @staticmethod - async def credit_release( + async def premium_release( db_session: AsyncSession, user_id: Any, reserved_micros: int, ) -> None: - """Release ``reserved_micros`` previously held by ``credit_reserve``. + """Release ``reserved_micros`` previously held by ``premium_reserve``. Used when a request fails before finalize (so the reservation doesn't leak credit). @@ -679,13 +659,13 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.credit_micros_reserved = max( - 0, user.credit_micros_reserved - reserved_micros + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @staticmethod - async def credit_get_usage( + async def premium_get_usage( db_session: AsyncSession, user_id: Any, ) -> QuotaResult: @@ -701,16 +681,24 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved - remaining = max(0, balance - reserved) + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved + remaining = max(0, limit - used - reserved) + + warning_threshold = int(limit * 0.8) + if used >= limit: + status = QuotaStatus.BLOCKED + elif used >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( - allowed=remaining > 0, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + allowed=used < limit, + status=status, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index d1a29b82a..3f07e6f9e 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -32,23 +32,6 @@ from app.db import TokenUsage logger = logging.getLogger(__name__) -def _bare_model_name(model: str) -> str: - """Return a model identifier with any provider routing prefix stripped. - - LiteLLM's ``get_llm_provider`` consumes the provider prefix we add in - ``to_litellm`` (e.g. ``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because - ``azure`` is in ``litellm.provider_list``). The token-tracking success - callback therefore reports ``kwargs["model"]`` *without* that prefix, - while model metadata is registered under the *prefixed* string. Normalising - both sides to the last path segment lets the two reconcile so the per-model - breakdown carries provider/display_name and the UI attributes the turn to - the correct connection instead of falling back to a bare-name collision. - """ - if not model: - return model - return model.split("/")[-1] - - @dataclass class TokenCallRecord: model: str @@ -57,10 +40,6 @@ class TokenCallRecord: total_tokens: int cost_micros: int = 0 call_kind: str = "chat" - model_ref: str | None = None - model_id: str | None = None - display_name: str | None = None - provider: str | None = None @dataclass @@ -68,46 +47,6 @@ class TurnTokenAccumulator: """Accumulates token usage across all LLM calls within a single user turn.""" calls: list[TokenCallRecord] = field(default_factory=list) - model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict) - # Secondary index keyed by the bare model name (provider prefix stripped) so - # the LiteLLM callback — which never sees our routing prefix — can still - # reconcile its ``kwargs["model"]`` back to the registered metadata. - model_metadata_by_bare: dict[str, dict[str, str | None]] = field( - default_factory=dict - ) - - def register_model_metadata( - self, - *, - model: str, - model_ref: str | None, - model_id: str | None, - display_name: str | None, - provider: str | None, - ) -> None: - """Attach resolved model metadata for later LiteLLM callback attribution.""" - metadata = { - "model_ref": model_ref, - "model_id": model_id, - "display_name": display_name, - "provider": provider, - } - self.model_metadata[model] = metadata - # Index every reconcilable alias: the prefixed string's bare form and - # the resolved ``model_id`` (which for some providers is itself the bare - # deployment LiteLLM reports). Exact lookups always take precedence. - self.model_metadata_by_bare[_bare_model_name(model)] = metadata - if model_id: - self.model_metadata_by_bare.setdefault(_bare_model_name(model_id), metadata) - - def _lookup_metadata(self, model: str) -> dict[str, str | None]: - """Resolve registered metadata for a callback model, tolerating the - provider-prefix stripping LiteLLM applies before the success callback - fires (see :func:`_bare_model_name`).""" - exact = self.model_metadata.get(model) - if exact is not None: - return exact - return self.model_metadata_by_bare.get(_bare_model_name(model), {}) def add( self, @@ -118,14 +57,9 @@ class TurnTokenAccumulator: cost_micros: int = 0, call_kind: str = "chat", ) -> None: - metadata = self._lookup_metadata(model) self.calls.append( TokenCallRecord( model=model, - model_ref=metadata.get("model_ref"), - model_id=metadata.get("model_id"), - display_name=metadata.get("display_name"), - provider=metadata.get("provider"), prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, @@ -134,18 +68,13 @@ class TurnTokenAccumulator: ) ) - def per_message_summary(self) -> dict[str, dict[str, Any]]: + def per_message_summary(self) -> dict[str, dict[str, int]]: """Return token counts (and cost) grouped by model name.""" - by_model: dict[str, dict[str, Any]] = {} + by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, { - "model": c.model, - "model_ref": c.model_ref, - "model_id": c.model_id, - "display_name": c.display_name, - "provider": c.provider, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, @@ -213,27 +142,6 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() -def register_model_usage_metadata( - *, - model: str, - model_ref: str | None, - model_id: str | None, - display_name: str | None, - provider: str | None, -) -> None: - """Register resolved model metadata with the current turn, if one exists.""" - acc = _turn_accumulator.get() - if acc is None: - return - acc.register_model_metadata( - model=model, - model_ref=model_ref, - model_id=model_id, - display_name=display_name, - provider=provider, - ) - - @asynccontextmanager async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: """Async context manager that scopes a fresh ``TurnTokenAccumulator`` diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py new file mode 100644 index 000000000..ed5de921c --- /dev/null +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -0,0 +1,201 @@ +import logging +from typing import Any + +from litellm import Router + +from app.services.provider_api_base import resolve_api_base + +logger = logging.getLogger(__name__) + +VISION_AUTO_MODE_ID = 0 + +VISION_PROVIDER_MAP = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GOOGLE": "gemini", + "AZURE_OPENAI": "azure", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "XAI": "xai", + "OPENROUTER": "openrouter", + "OLLAMA": "ollama_chat", + "GROQ": "groq", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "MISTRAL": "mistral", + "CUSTOM": "custom", +} + + +class VisionLLMRouterService: + _instance = None + _router: Router | None = None + _model_list: list[dict] = [] + _router_settings: dict = {} + _initialized: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls) -> "VisionLLMRouterService": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + instance = cls.get_instance() + + if instance._initialized: + logger.debug("Vision LLM Router already initialized, skipping") + return + + model_list = [] + for config in global_configs: + deployment = cls._config_to_deployment(config) + if deployment: + model_list.append(deployment) + + if not model_list: + logger.warning( + "No valid vision LLM configs found for router initialization" + ) + return + + instance._model_list = model_list + instance._router_settings = router_settings or {} + + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + "retry_after": 5, + } + + final_settings = {**default_settings, **instance._router_settings} + + try: + instance._router = Router( + model_list=model_list, + routing_strategy=final_settings.get( + "routing_strategy", "usage-based-routing" + ), + num_retries=final_settings.get("num_retries", 3), + allowed_fails=final_settings.get("allowed_fails", 3), + cooldown_time=final_settings.get("cooldown_time", 60), + set_verbose=False, + ) + instance._initialized = True + logger.info( + "Vision LLM Router initialized with %d deployments, strategy: %s", + len(model_list), + final_settings.get("routing_strategy"), + ) + except Exception as e: + logger.error(f"Failed to initialize Vision LLM Router: {e}") + instance._router = None + + @classmethod + def _config_to_deployment(cls, config: dict) -> dict | None: + try: + if not config.get("model_name") or not config.get("api_key"): + return None + + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" + else: + provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base + + if config.get("api_version"): + litellm_params["api_version"] = config["api_version"] + + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) + + deployment: dict[str, Any] = { + "model_name": "auto", + "litellm_params": litellm_params, + } + + if config.get("rpm"): + deployment["rpm"] = config["rpm"] + if config.get("tpm"): + deployment["tpm"] = config["tpm"] + + return deployment + + except Exception as e: + logger.warning(f"Failed to convert vision config to deployment: {e}") + return None + + @classmethod + def get_router(cls) -> Router | None: + instance = cls.get_instance() + return instance._router + + @classmethod + def is_initialized(cls) -> bool: + instance = cls.get_instance() + return instance._initialized and instance._router is not None + + @classmethod + def get_model_count(cls) -> int: + instance = cls.get_instance() + return len(instance._model_list) + + +def is_vision_auto_mode(config_id: int | None) -> bool: + return config_id == VISION_AUTO_MODE_ID + + +def build_vision_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +def get_global_vision_llm_config(config_id: int) -> dict | None: + from app.config import config + + if config_id == VISION_AUTO_MODE_ID: + return { + "id": VISION_AUTO_MODE_ID, + "name": "Auto (Fastest)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py new file mode 100644 index 000000000..fc459910b --- /dev/null +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -0,0 +1,134 @@ +""" +Service for fetching and caching the vision-capable model list. + +Reuses the same OpenRouter public API and local fallback as the LLM model +list service, but filters for models that accept image input. +""" + +import json +import logging +import time +from pathlib import Path + +import httpx + +logger = logging.getLogger(__name__) + +OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +FALLBACK_FILE = ( + Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" +) +CACHE_TTL_SECONDS = 86400 # 24 hours + +_cache: list[dict] | None = None +_cache_timestamp: float = 0 + +OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { + "openai": "OPENAI", + "anthropic": "ANTHROPIC", + "google": "GOOGLE", + "mistralai": "MISTRAL", + "x-ai": "XAI", +} + + +def _format_context_length(length: int | None) -> str | None: + if not length: + return None + if length >= 1_000_000: + return f"{length / 1_000_000:g}M" + if length >= 1_000: + return f"{length / 1_000:g}K" + return str(length) + + +async def _fetch_from_openrouter() -> list[dict] | None: + try: + async with httpx.AsyncClient(timeout=15) as client: + response = await client.get(OPENROUTER_API_URL) + response.raise_for_status() + data = response.json() + return data.get("data", []) + except Exception as e: + logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) + return None + + +def _load_fallback() -> list[dict]: + try: + with open(FALLBACK_FILE, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error("Failed to load vision model fallback list: %s", e) + return [] + + +def _is_vision_model(model: dict) -> bool: + """Return True if the model accepts image input and outputs text.""" + arch = model.get("architecture", {}) + input_mods = arch.get("input_modalities", []) + output_mods = arch.get("output_modalities", []) + return "image" in input_mods and "text" in output_mods + + +def _process_vision_models(raw_models: list[dict]) -> list[dict]: + processed: list[dict] = [] + + for model in raw_models: + model_id: str = model.get("id", "") + name: str = model.get("name", "") + context_length = model.get("context_length") + + if "/" not in model_id: + continue + + if not _is_vision_model(model): + continue + + provider_slug, model_name = model_id.split("/", 1) + context_window = _format_context_length(context_length) + + processed.append( + { + "value": model_id, + "label": name, + "provider": "OPENROUTER", + "context_window": context_window, + } + ) + + native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if native_provider: + if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + continue + + processed.append( + { + "value": model_name, + "label": name, + "provider": native_provider, + "context_window": context_window, + } + ) + + return processed + + +async def get_vision_model_list() -> list[dict]: + global _cache, _cache_timestamp + + if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: + return _cache + + raw_models = await _fetch_from_openrouter() + + if raw_models is None: + logger.info("Using fallback vision model list") + return _load_fallback() + + processed = _process_vision_models(raw_models) + + _cache = processed + _cache_timestamp = time.time() + + return processed diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index a1113884f..6ea7a2e68 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -32,27 +32,10 @@ def get_celery_session_maker() -> async_sessionmaker: """ global _celery_engine, _celery_session_maker if _celery_session_maker is None: - # Reap connections orphaned mid-transaction (e.g. a worker that hung or - # crashed mid-index) so they can't hold locks on documents/chunks and - # wedge writes — the failure mode that previously left an "idle in - # transaction" session holding locks for 11+ hours. Kept generous so a - # legitimate long per-document embed window is never killed. - connect_args: dict = {} - idle_ms = config.DB_CELERY_IDLE_IN_TX_TIMEOUT_MS - if ( - idle_ms - and idle_ms > 0 - and config.DATABASE_URL - and "asyncpg" in config.DATABASE_URL - ): - connect_args["server_settings"] = { - "idle_in_transaction_session_timeout": str(idle_ms) - } _celery_engine = create_async_engine( config.DATABASE_URL, poolclass=NullPool, echo=False, - connect_args=connect_args, ) with contextlib.suppress(Exception): from app.observability.bootstrap import instrument_sqlalchemy_engine diff --git a/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py b/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py deleted file mode 100644 index 385cdde88..000000000 --- a/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Debit-triggered off-session credit auto-reload. - -Enqueued (best-effort) by ``auto_reload_service.maybe_trigger_auto_reload`` -after a credit debit drops the wallet below the user's threshold. This task is -the authoritative path: it re-checks eligibility under a row lock, enforces the -cooldown, then charges the saved card off-session via a Stripe PaymentIntent -(Stripe: charging a saved card off-session). - -Idempotency comes from three layers: -- a per-attempt CreditPurchase row created PENDING before the charge, -- a Stripe idempotency key derived from that row id, -- the ``payment_intent.*`` webhook backstop in ``stripe_routes`` that only - transitions PENDING rows. -""" - -from __future__ import annotations - -import logging -import uuid -from datetime import UTC, datetime, timedelta - -from sqlalchemy import select -from stripe import CardError, StripeClient, StripeError - -from app.celery_app import celery_app -from app.config import config -from app.db import CreditPurchase, CreditPurchaseStatus, User -from app.notifications.service import NotificationService -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - -# 1_000_000 micro-USD == $1.00 == 100 cents, so cents = micros / 10_000. -_MICROS_PER_CENT = 10_000 - - -def _get_stripe_client() -> StripeClient | None: - if not config.STRIPE_SECRET_KEY: - logger.warning("Auto-reload skipped because STRIPE_SECRET_KEY is not set.") - return None - return StripeClient(config.STRIPE_SECRET_KEY) - - -def _card_error_payment_intent_id(exc: CardError) -> str | None: - """Pull the PaymentIntent id off a declined off-session charge. - - Per Stripe's off-session guide the failed intent is on ``exc.error.payment_intent``, - which may be a StripeObject or a plain dict depending on the SDK path. - """ - err = getattr(exc, "error", None) - pi = getattr(err, "payment_intent", None) if err is not None else None - if pi is None: - return None - if isinstance(pi, dict): - return pi.get("id") - return getattr(pi, "id", None) - - -@celery_app.task(name="auto_reload_credits") -def auto_reload_credits_task(user_id: str): - """Charge the user's saved card to top up credits when below threshold.""" - return run_async_celery_task(lambda: _auto_reload_credits(user_id)) - - -async def _auto_reload_credits(user_id: str) -> None: - if not config.AUTO_RELOAD_ENABLED: - return - - stripe_client = _get_stripe_client() - if stripe_client is None: - return - - cooldown = timedelta(minutes=max(config.AUTO_RELOAD_COOLDOWN_MINUTES, 0)) - now = datetime.now(UTC) - cutoff = now - cooldown - - async with get_celery_session_maker()() as db_session: - # Lock the user row so concurrent debits/tasks can't double-charge. - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is None or not user.auto_reload_enabled: - return - - if not (user.stripe_customer_id and user.auto_reload_payment_method_id): - return - - threshold = user.auto_reload_threshold_micros - amount = user.auto_reload_amount_micros - if not threshold or not amount: - return - - available = user.credit_micros_balance - user.credit_micros_reserved - if available >= threshold: - # Another reload (or a refund/grant) already restored the balance. - return - - # Cooldown: skip if a recent auto-reload purchase or failure happened. - recent = ( - await db_session.execute( - select(CreditPurchase.id) - .where( - CreditPurchase.user_id == user.id, - CreditPurchase.source == "auto_reload", - CreditPurchase.created_at >= cutoff, - CreditPurchase.status.in_( - [ - CreditPurchaseStatus.PENDING, - CreditPurchaseStatus.COMPLETED, - ] - ), - ) - .limit(1) - ) - ).first() - if recent is not None: - return - if user.auto_reload_failed_at and user.auto_reload_failed_at >= cutoff: - return - - customer_id = user.stripe_customer_id - payment_method_id = user.auto_reload_payment_method_id - amount_cents = max(round(amount / _MICROS_PER_CENT), 1) - - # Create the PENDING purchase row first so its id seeds the Stripe - # idempotency key and the webhook backstop can find it. - purchase = CreditPurchase( - user_id=user.id, - stripe_checkout_session_id=f"auto_reload:{uuid.uuid4()}", - quantity=0, - credit_micros_granted=amount, - amount_total=amount_cents, - currency="usd", - source="auto_reload", - status=CreditPurchaseStatus.PENDING, - ) - db_session.add(purchase) - await db_session.flush() - purchase_id = purchase.id - await db_session.commit() - - # Charge off-session outside the user-row lock so the network call doesn't - # hold the row. The purchase row is the synchronization point now. - try: - payment_intent = stripe_client.v1.payment_intents.create( - params={ - "amount": amount_cents, - "currency": "usd", - "customer": customer_id, - "payment_method": payment_method_id, - "off_session": True, - "confirm": True, - "metadata": { - "user_id": str(user_id), - "purchase_type": "auto_reload", - "purchase_id": str(purchase_id), - }, - }, - options={"idempotency_key": f"auto_reload:{purchase_id}"}, - ) - except CardError as exc: - await _record_failure( - purchase_id, - user_id, - amount, - payment_intent_id=_card_error_payment_intent_id(exc), - reason=getattr(exc, "user_message", None) or "Your card was declined.", - ) - return - except StripeError: - logger.exception("Auto-reload charge failed for user %s", user_id) - await _record_failure( - purchase_id, - user_id, - amount, - payment_intent_id=None, - reason="We couldn't process the charge. Please try again.", - ) - return - - payment_intent_id = str(payment_intent.id) - pi_status = getattr(payment_intent, "status", None) - - async with get_celery_session_maker()() as db_session: - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.id == purchase_id) - .with_for_update() - ) - ).scalar_one_or_none() - if purchase is None: - return - purchase.stripe_payment_intent_id = payment_intent_id - - if pi_status == "succeeded": - if purchase.status != CreditPurchaseStatus.COMPLETED: - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == purchase.user_id) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one() - ) - purchase.status = CreditPurchaseStatus.COMPLETED - purchase.completed_at = datetime.now(UTC) - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) - user.auto_reload_failed_at = None - await db_session.commit() - logger.info( - "Auto-reload succeeded for user %s (+%s micro-USD)", - user_id, - amount, - ) - return - - # Not succeeded synchronously (e.g. requires_action / processing). - # Leave the row PENDING; the payment_intent webhook reconciles it. - await db_session.commit() - logger.info( - "Auto-reload PaymentIntent %s for user %s is %s; awaiting webhook.", - payment_intent_id, - user_id, - pi_status, - ) - - -async def _record_failure( - purchase_id: uuid.UUID, - user_id: str, - amount_micros: int, - *, - payment_intent_id: str | None, - reason: str | None, -) -> None: - """Mark the purchase FAILED, stamp the user, and notify them.""" - async with get_celery_session_maker()() as db_session: - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.id == purchase_id) - .with_for_update() - ) - ).scalar_one_or_none() - if purchase is not None and purchase.status == CreditPurchaseStatus.PENDING: - purchase.status = CreditPurchaseStatus.FAILED - if payment_intent_id: - purchase.stripe_payment_intent_id = payment_intent_id - - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is not None: - user.auto_reload_failed_at = datetime.now(UTC) - # Disable so a declined card doesn't get retried every debit; the - # user re-enables from settings (which clears the failure flag). - user.auto_reload_enabled = False - - await db_session.commit() - - try: - await NotificationService.auto_reload_failed.notify_auto_reload_failed( - session=db_session, - user_id=uuid.UUID(user_id), - amount_micros=amount_micros, - payment_intent_id=payment_intent_id, - reason=reason, - ) - except Exception: - logger.warning( - "Failed to create auto_reload_failed notification for user %s", - user_id, - exc_info=True, - ) diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 4d71d6c9a..d38014124 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -602,29 +602,23 @@ async def _process_file_upload( # Create notification for document processing logger.info(f"[_process_file_upload] Creating notification for: {filename}") - notification = None - heartbeat_task = None - try: - notification = ( - await NotificationService.document_processing.notify_processing_started( - session=session, - user_id=UUID(user_id), - document_type="FILE", - document_name=filename, - search_space_id=search_space_id, - file_size=file_size, - ) - ) - logger.info( - f"[_process_file_upload] Notification created with ID: {notification.id}" - ) - _start_heartbeat(notification.id) - heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) - except Exception: - logger.warning( - f"[_process_file_upload] Failed to create notification for: {filename}", - exc_info=True, + notification = ( + await NotificationService.document_processing.notify_processing_started( + session=session, + user_id=UUID(user_id), + document_type="FILE", + document_name=filename, + search_space_id=search_space_id, + file_size=file_size, ) + ) + logger.info( + f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}" + ) + + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) log_entry = await task_logger.log_task_start( task_name="process_file_upload", @@ -652,82 +646,83 @@ async def _process_file_upload( # Update notification on success if result: - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, document_id=result.id, chunks_count=None, ) + ) else: # Duplicate detected - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, error_message="Document already exists (duplicate)", ) + ) except Exception as e: # Import here to avoid circular dependencies from fastapi import HTTPException - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - # Check if this is an insufficient-credit error (either direct or - # wrapped in HTTPException) - credit_error: InsufficientCreditsError | None = None - if isinstance(e, InsufficientCreditsError): - credit_error = e + # Check if this is a page limit error (either direct or wrapped in HTTPException) + page_limit_error: PageLimitExceededError | None = None + if isinstance(e, PageLimitExceededError): + page_limit_error = e elif ( isinstance(e, HTTPException) and e.__cause__ - and isinstance(e.__cause__, InsufficientCreditsError) + and isinstance(e.__cause__, PageLimitExceededError) ): - # HTTPException wraps the original InsufficientCreditsError - credit_error = e.__cause__ - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): - # Fallback: HTTPException with credit message but no cause - credit_error = None # We don't have the details + # HTTPException wraps the original PageLimitExceededError + page_limit_error = e.__cause__ + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): + # Fallback: HTTPException with page limit message but no cause + page_limit_error = None # We don't have the details - # For insufficient-credit errors, create a dedicated notification - if credit_error is not None: - error_message = str(credit_error) - # Create a dedicated insufficient credits notification + # For page limit errors, create a dedicated page_limit_exceeded notification + if page_limit_error is not None: + error_message = str(page_limit_error) + # Create a dedicated page limit exceeded notification try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message="Insufficient credits", - ) + # First, mark the processing notification as failed + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message="Page limit exceeded", + ) - # Then create a separate insufficient_credits notification for better UX - await NotificationService.insufficient_credits.notify_insufficient_credits( + # Then create a separate page_limit_exceeded notification for better UX + await NotificationService.page_limit.notify_page_limit_exceeded( session=session, user_id=UUID(user_id), document_name=filename, document_type="FILE", search_space_id=search_space_id, - balance_micros=credit_error.balance_micros, - required_micros=credit_error.required_micros, + pages_used=page_limit_error.pages_used, + pages_limit=page_limit_error.pages_limit, + pages_to_add=page_limit_error.pages_to_add, ) except Exception as notif_error: logger.error( - f"Failed to create insufficient credits notification: {notif_error!s}" + f"Failed to create page limit notification: {notif_error!s}" ) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): # HTTPException with page limit message but no detailed cause error_message = str(e.detail) try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=error_message, - ) + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=error_message, + ) except Exception as notif_error: logger.error( f"Failed to update notification on failure: {notif_error!s}" @@ -736,13 +731,13 @@ async def _process_file_upload( error_message = str(e)[:100] # Update notification on failure - wrapped in try-except to ensure it doesn't fail silently try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=error_message, - ) + # Refresh notification to ensure it's not stale after any rollback + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=error_message, + ) except Exception as notif_error: logger.error( f"Failed to update notification on failure: {notif_error!s}" @@ -758,10 +753,8 @@ async def _process_file_upload( raise finally: # Stop heartbeat — key deleted on success, expires on crash - if heartbeat_task: - heartbeat_task.cancel() - if notification: - _stop_heartbeat(notification.id) + heartbeat_task.cancel() + _stop_heartbeat(notification.id) @celery_app.task(name="process_file_upload_with_document", bind=True) @@ -901,36 +894,29 @@ async def _process_file_with_document( logger.info( f"[_process_file_with_document] Creating notification for: {filename}" ) - notification = None - heartbeat_task = None - try: - notification = ( - await NotificationService.document_processing.notify_processing_started( - session=session, - user_id=UUID(user_id), - document_type="FILE", - document_name=filename, - search_space_id=search_space_id, - file_size=file_size, - ) + notification = ( + await NotificationService.document_processing.notify_processing_started( + session=session, + user_id=UUID(user_id), + document_type="FILE", + document_name=filename, + search_space_id=search_space_id, + file_size=file_size, ) + ) - # Store document_id in notification metadata so cleanup task can find the document - if notification.notification_metadata is not None: - notification.notification_metadata["document_id"] = document_id - from sqlalchemy.orm.attributes import flag_modified + # Store document_id in notification metadata so cleanup task can find the document + if notification and notification.notification_metadata is not None: + notification.notification_metadata["document_id"] = document_id + from sqlalchemy.orm.attributes import flag_modified - flag_modified(notification, "notification_metadata") - await session.commit() - await session.refresh(notification) + flag_modified(notification, "notification_metadata") + await session.commit() + await session.refresh(notification) - _start_heartbeat(notification.id) - heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) - except Exception: - logger.warning( - f"[_process_file_with_document] Failed to create notification for: {filename}", - exc_info=True, - ) + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) log_entry = await task_logger.log_task_start( task_name="process_file_upload_with_document", @@ -970,13 +956,14 @@ async def _process_file_with_document( # Update notification on success if result: - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, document_id=result.id, chunks_count=None, ) + ) logger.info( f"[_process_file_with_document] Successfully processed document {document_id}" ) @@ -985,29 +972,30 @@ async def _process_file_with_document( document.status = DocumentStatus.failed("Duplicate content detected") document.updated_at = get_current_timestamp() await session.commit() - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, error_message="Document already exists (duplicate)", ) + ) except Exception as e: # Import here to avoid circular dependencies from fastapi import HTTPException - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - # Check if this is an insufficient-credit error - credit_error: InsufficientCreditsError | None = None - if isinstance(e, InsufficientCreditsError): - credit_error = e + # Check if this is a page limit error + page_limit_error: PageLimitExceededError | None = None + if isinstance(e, PageLimitExceededError): + page_limit_error = e elif ( isinstance(e, HTTPException) and e.__cause__ - and isinstance(e.__cause__, InsufficientCreditsError) + and isinstance(e.__cause__, PageLimitExceededError) ): - credit_error = e.__cause__ + page_limit_error = e.__cause__ # Mark document as failed (shows error in UI via Zero) error_message = str(e)[:500] @@ -1018,39 +1006,38 @@ async def _process_file_with_document( f"[_process_file_with_document] Document {document_id} marked as failed: {error_message[:100]}" ) - # Handle insufficient-credit errors with dedicated notification - if credit_error is not None: + # Handle page limit errors with dedicated notification + if page_limit_error is not None: try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message="Insufficient credits", - ) - await NotificationService.insufficient_credits.notify_insufficient_credits( + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message="Page limit exceeded", + ) + await NotificationService.page_limit.notify_page_limit_exceeded( session=session, user_id=UUID(user_id), document_name=filename, document_type="FILE", search_space_id=search_space_id, - balance_micros=credit_error.balance_micros, - required_micros=credit_error.required_micros, + pages_used=page_limit_error.pages_used, + pages_limit=page_limit_error.pages_limit, + pages_to_add=page_limit_error.pages_to_add, ) except Exception as notif_error: logger.error( - f"Failed to create insufficient credits notification: {notif_error!s}" + f"Failed to create page limit notification: {notif_error!s}" ) else: # Update notification on failure try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=str(e)[:100], - ) + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=str(e)[:100], + ) except Exception as notif_error: logger.error( f"Failed to update notification on failure: {notif_error!s}" @@ -1067,10 +1054,8 @@ async def _process_file_with_document( finally: # Stop heartbeat — key deleted on success, expires on crash - if heartbeat_task: - heartbeat_task.cancel() - if notification: - _stop_heartbeat(notification.id) + heartbeat_task.cancel() + _stop_heartbeat(notification.id) # Clean up temp file if os.path.exists(temp_path): diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py new file mode 100644 index 000000000..8b311576e --- /dev/null +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -0,0 +1,236 @@ +"""Celery tasks for podcast generation.""" + +import asyncio +import logging +import sys +from contextlib import asynccontextmanager + +from sqlalchemy import select + +from app.agents.podcaster.graph import graph as podcaster_graph +from app.agents.podcaster.state import State as PodcasterState +from app.celery_app import celery_app +from app.config import config as app_config +from app.db import Podcast, PodcastStatus +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task + +logger = logging.getLogger(__name__) + +if sys.platform.startswith("win"): + try: + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + except AttributeError: + logger.warning( + "WindowsProactorEventLoopPolicy is unavailable; async subprocess support may fail." + ) + + +# ============================================================================= +# Content-based podcast generation (for new-chat) +# ============================================================================= + + +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + +@celery_app.task(name="generate_content_podcast", bind=True) +def generate_content_podcast_task( + self, + podcast_id: int, + source_content: str, + search_space_id: int, + user_prompt: str | None = None, +) -> dict: + """ + Celery task to generate podcast from source content. + Updates existing podcast record created by the tool. + """ + try: + return run_async_celery_task( + lambda: _generate_content_podcast( + podcast_id, + source_content, + search_space_id, + user_prompt, + ) + ) + except Exception as e: + logger.error(f"Error generating content podcast: {e!s}") + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) + return {"status": "failed", "podcast_id": podcast_id} + + +async def _mark_podcast_failed(podcast_id: int) -> None: + """Mark a podcast as failed in the database.""" + async with get_celery_session_maker()() as session: + try: + result = await session.execute( + select(Podcast).filter(Podcast.id == podcast_id) + ) + podcast = result.scalars().first() + if podcast: + podcast.status = PodcastStatus.FAILED + await session.commit() + except Exception as e: + logger.error(f"Failed to mark podcast as failed: {e}") + + +async def _generate_content_podcast( + podcast_id: int, + source_content: str, + search_space_id: int, + user_prompt: str | None = None, +) -> dict: + """Generate content-based podcast and update existing record.""" + async with get_celery_session_maker()() as session: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise ValueError(f"Podcast {podcast_id} not found") + + try: + podcast.status = PodcastStatus.GENERATING + await session.commit() + + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + + graph_config = { + "configurable": { + "podcast_title": podcast.title, + "search_space_id": search_space_id, + "user_prompt": user_prompt, + } + } + + initial_state = PodcasterState( + source_content=source_content, + db_session=session, + ) + + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + "thread_id": podcast.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } + + podcast_transcript = graph_result.get("podcast_transcript", []) + file_path = graph_result.get("final_podcast_file_path", "") + + serializable_transcript = [] + for entry in podcast_transcript: + if hasattr(entry, "speaker_id"): + serializable_transcript.append( + {"speaker_id": entry.speaker_id, "dialog": entry.dialog} + ) + else: + serializable_transcript.append( + { + "speaker_id": entry.get("speaker_id", 0), + "dialog": entry.get("dialog", ""), + } + ) + + podcast.podcast_transcript = serializable_transcript + podcast.file_location = file_path + podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + len(serializable_transcript), + file_path, + ) + await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) + + logger.info(f"Successfully generated podcast: {podcast.id}") + + return { + "status": "ready", + "podcast_id": podcast.id, + "title": podcast.title, + "transcript_entries": len(serializable_transcript), + } + + except Exception as e: + logger.error(f"Error in _generate_content_podcast: {e!s}") + podcast.status = PodcastStatus.FAILED + await session.commit() + raise diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index f1ed6c6b3..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -1,4 +1,4 @@ -"""Reconcile pending Stripe credit purchases that might miss webhook fulfillment.""" +"""Reconcile pending Stripe purchases that might miss webhook fulfillment.""" from __future__ import annotations @@ -11,8 +11,10 @@ from stripe import StripeClient, StripeError from app.celery_app import celery_app from app.config import config from app.db import ( - CreditPurchase, - CreditPurchaseStatus, + PagePurchase, + PagePurchaseStatus, + PremiumTokenPurchase, + PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task @@ -30,14 +32,14 @@ def get_stripe_client() -> StripeClient | None: return StripeClient(config.STRIPE_SECRET_KEY) -@celery_app.task(name="reconcile_pending_stripe_credit_purchases") -def reconcile_pending_stripe_credit_purchases_task(): - """Recover paid credit purchases that were left pending due to missed webhook handling.""" - return run_async_celery_task(_reconcile_pending_credit_purchases) +@celery_app.task(name="reconcile_pending_stripe_page_purchases") +def reconcile_pending_stripe_page_purchases_task(): + """Recover paid purchases that were left pending due to missed webhook handling.""" + return run_async_celery_task(_reconcile_pending_page_purchases) -async def _reconcile_pending_credit_purchases() -> None: - """Reconcile stale pending credit purchases against Stripe source of truth. +async def _reconcile_pending_page_purchases() -> None: + """Reconcile stale pending page purchases against Stripe source of truth. Stripe retries webhook delivery automatically, but best practice is to add an application-level reconciliation path in case all retries fail or the endpoint @@ -55,12 +57,12 @@ async def _reconcile_pending_credit_purchases() -> None: pending_purchases = ( ( await db_session.execute( - select(CreditPurchase) + select(PagePurchase) .where( - CreditPurchase.status == CreditPurchaseStatus.PENDING, - CreditPurchase.created_at <= cutoff, + PagePurchase.status == PagePurchaseStatus.PENDING, + PagePurchase.created_at <= cutoff, ) - .order_by(CreditPurchase.created_at.asc()) + .order_by(PagePurchase.created_at.asc()) .limit(batch_size) ) ) @@ -70,13 +72,13 @@ async def _reconcile_pending_credit_purchases() -> None: if not pending_purchases: logger.debug( - "Stripe credit reconciliation found no pending purchases older than %s minutes.", + "Stripe reconciliation found no pending purchases older than %s minutes.", lookback_minutes, ) return logger.info( - "Stripe credit reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", + "Stripe reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", len(pending_purchases), lookback_minutes, batch_size, @@ -94,7 +96,7 @@ async def _reconcile_pending_credit_purchases() -> None: ) except StripeError: logger.exception( - "Stripe credit reconciliation failed to retrieve checkout session %s", + "Stripe reconciliation failed to retrieve checkout session %s", checkout_session_id, ) await db_session.rollback() @@ -105,24 +107,119 @@ async def _reconcile_pending_credit_purchases() -> None: try: if payment_status in {"paid", "no_payment_required"}: - await stripe_routes._fulfill_completed_credit_purchase( + await stripe_routes._fulfill_completed_purchase( db_session, checkout_session ) fulfilled_count += 1 elif session_status == "expired": - await stripe_routes._mark_credit_purchase_failed( + await stripe_routes._mark_purchase_failed( db_session, str(checkout_session.id) ) failed_count += 1 except Exception: logger.exception( - "Stripe credit reconciliation failed while processing checkout session %s", + "Stripe reconciliation failed while processing checkout session %s", checkout_session_id, ) await db_session.rollback() logger.info( - "Stripe credit reconciliation completed. fulfilled=%s failed=%s checked=%s", + "Stripe page reconciliation completed. fulfilled=%s failed=%s checked=%s", + fulfilled_count, + failed_count, + len(pending_purchases), + ) + + +@celery_app.task(name="reconcile_pending_stripe_token_purchases") +def reconcile_pending_stripe_token_purchases_task(): + """Recover paid token purchases that were left pending due to missed webhook handling.""" + return run_async_celery_task(_reconcile_pending_token_purchases) + + +async def _reconcile_pending_token_purchases() -> None: + """Reconcile stale pending token purchases against Stripe source of truth.""" + stripe_client = get_stripe_client() + if stripe_client is None: + return + + lookback_minutes = max(config.STRIPE_RECONCILIATION_LOOKBACK_MINUTES, 0) + batch_size = max(config.STRIPE_RECONCILIATION_BATCH_SIZE, 1) + cutoff = datetime.now(UTC) - timedelta(minutes=lookback_minutes) + + async with get_celery_session_maker()() as db_session: + pending_purchases = ( + ( + await db_session.execute( + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.status + == PremiumTokenPurchaseStatus.PENDING, + PremiumTokenPurchase.created_at <= cutoff, + ) + .order_by(PremiumTokenPurchase.created_at.asc()) + .limit(batch_size) + ) + ) + .scalars() + .all() + ) + + if not pending_purchases: + logger.debug( + "Stripe token reconciliation found no pending purchases older than %s minutes.", + lookback_minutes, + ) + return + + logger.info( + "Stripe token reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", + len(pending_purchases), + lookback_minutes, + batch_size, + ) + + fulfilled_count = 0 + failed_count = 0 + + for purchase in pending_purchases: + checkout_session_id = purchase.stripe_checkout_session_id + + try: + checkout_session = stripe_client.v1.checkout.sessions.retrieve( + checkout_session_id + ) + except StripeError: + logger.exception( + "Stripe token reconciliation failed to retrieve checkout session %s", + checkout_session_id, + ) + await db_session.rollback() + continue + + payment_status = getattr(checkout_session, "payment_status", None) + session_status = getattr(checkout_session, "status", None) + + try: + if payment_status in {"paid", "no_payment_required"}: + await stripe_routes._fulfill_completed_token_purchase( + db_session, checkout_session + ) + fulfilled_count += 1 + elif session_status == "expired": + await stripe_routes._mark_token_purchase_failed( + db_session, str(checkout_session.id) + ) + failed_count += 1 + except Exception: + logger.exception( + "Stripe token reconciliation failed while processing checkout session %s", + checkout_session_id, + ) + await db_session.rollback() + + logger.info( + "Stripe token reconciliation completed. fulfilled=%s failed=%s checked=%s", fulfilled_count, failed_count, len(pending_purchases), diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index c6ce0b350..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -174,10 +174,11 @@ async def _generate_video_presentation( ) except QuotaInsufficientError as exc: logger.info( - "VideoPresentation %s denied: out of credits " - "(balance=%d remaining=%d)", + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", video_pres.id, - exc.balance_micros, + exc.used_micros, + exc.limit_micros, exc.remaining_micros, ) video_pres.status = VideoPresentationStatus.FAILED diff --git a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py deleted file mode 100644 index 3394913c3..000000000 --- a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Convert persisted chat content into provider-safe LangChain history. - -Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module -extracts only model-safe content before prior turns are replayed to a provider. -""" - -from __future__ import annotations - -from typing import Any - -_USER_CONTENT_TYPES = {"text", "image", "image_url"} - - -def _text_from_block(block: dict[str, Any]) -> str: - value = block.get("text") or block.get("content") or "" - return value if isinstance(value, str) else "" - - -def assistant_content_to_llm_text(content: Any) -> str: - """Return visible assistant text, dropping reasoning/UI/provider blocks.""" - if isinstance(content, str): - return content - if isinstance(content, dict): - return _text_from_block(content) - if not isinstance(content, list): - return "" - - text_chunks: list[str] = [] - for block in content: - if isinstance(block, str): - if block: - text_chunks.append(block) - continue - if not isinstance(block, dict): - continue - if block.get("type") == "text": - text = _text_from_block(block) - if text: - text_chunks.append(text) - return "\n".join(text_chunks) - - -def user_content_to_llm_content( - content: Any, - *, - allow_images: bool = True, -) -> str | list[dict[str, Any]]: - """Return provider-safe user text/image content for LangChain.""" - if isinstance(content, str): - return content - if isinstance(content, dict): - return _text_from_block(content) - if not isinstance(content, list): - return "" - - parts: list[dict[str, Any]] = [] - text_chunks: list[str] = [] - for block in content: - if isinstance(block, str): - if block: - text_chunks.append(block) - continue - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type not in _USER_CONTENT_TYPES: - continue - if block_type == "text": - text = _text_from_block(block) - if text: - parts.append({"type": "text", "text": text}) - text_chunks.append(text) - elif allow_images and block_type == "image": - image = block.get("image") - if isinstance(image, str) and image.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": image}}) - elif allow_images and block_type == "image_url": - image_url = block.get("image_url") - if isinstance(image_url, dict): - url = image_url.get("url") - if isinstance(url, str) and url.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": url}}) - elif isinstance(image_url, str) and image_url.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": image_url}}) - - if allow_images and any(part.get("type") == "image_url" for part in parts): - return parts - return "\n".join(text_chunks) diff --git a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py deleted file mode 100644 index a4b636538..000000000 --- a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Normalize final LangChain assistant messages into assistant-ui parts. - -Live streaming remains the primary source for rich, incremental UI state. -This module is only used after the graph has finished so refresh persistence -does not depend on provider-specific streaming chunk shapes. -""" - -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any - -from langchain_core.messages import AIMessage - - -def _text_from_content(content: Any) -> str: - if isinstance(content, str): - return content - if not isinstance(content, list): - return "" - - text_parts: list[str] = [] - for block in content: - if not isinstance(block, dict): - continue - if block.get("type") != "text": - continue - value = block.get("text") or block.get("content") or "" - if isinstance(value, str) and value: - text_parts.append(value) - return "".join(text_parts) - - -def normalize_ai_message_to_parts( - message: AIMessage | Any | None, -) -> list[dict[str, Any]]: - """Return user-visible assistant-ui parts for a final AI message. - - We intentionally do not backfill provider ``thinking`` / - ``reasoning_content`` blocks here. If reasoning streamed live, the - ``AssistantContentBuilder`` already captured it. If it only exists in the - final model payload, persisting it retroactively could expose content the - UI never showed during the turn. - """ - if message is None: - return [] - - text = _text_from_content(getattr(message, "content", None)).strip() - if not text: - return [] - return [{"type": "text", "text": text}] - - -def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None: - if messages is None: - return None - for message in reversed(list(messages)): - if isinstance(message, AIMessage): - return message - if getattr(message, "type", None) == "ai": - return message - return None - - -def final_assistant_parts_from_messages( - messages: Iterable[Any] | None, -) -> list[dict[str, Any]]: - return normalize_ai_message_to_parts(last_ai_message(messages)) - - -def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool: - return any( - part.get("type") == "text" - and isinstance(part.get("text"), str) - and bool(part.get("text", "").strip()) - for part in parts - ) - - -def merge_streamed_and_final_parts( - streamed_parts: list[dict[str, Any]], - final_parts: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Use final-state text only when streaming captured no answer text.""" - if has_non_empty_text_part(streamed_parts): - return streamed_parts - if not has_non_empty_text_part(final_parts): - return streamed_parts - return [*streamed_parts, *final_parts] diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py index 939cd9b17..d96144bcd 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py @@ -16,9 +16,6 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence impor ) from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.message_parts_normalizer import ( - final_assistant_parts_from_messages, -) from app.tasks.chat.streaming.contract.file_contract import ( contract_enforcement_active, evaluate_file_contract_outcome, @@ -78,9 +75,6 @@ async def stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} - result.final_message_parts = final_assistant_parts_from_messages( - state_values.get("messages") - ) # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py index 3fc5918ee..6b37df343 100644 --- a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -12,7 +12,6 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( is_cancel_requested, ) from app.agents.chat.runtime.errors import BusyError -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 @@ -103,9 +102,6 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: def is_provider_rate_limited(exc: BaseException) -> bool: """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" - if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED: - return True - raw = str(exc) lowered = raw.lower() if "ratelimit" in type(exc).__name__.lower(): @@ -135,85 +131,6 @@ def is_provider_rate_limited(exc: BaseException) -> bool: ) -def _provider_error_extra(adapted: Any) -> dict[str, Any] | None: - extra: dict[str, Any] = {"provider_error_category": adapted.category.value} - if adapted.provider_status_code is not None: - extra["provider_status_code"] = adapted.provider_status_code - if adapted.provider_error_type: - extra["provider_error_type"] = adapted.provider_error_type - return extra - - -def _classify_provider_exception( - exc: Exception, -) -> ( - tuple[str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None] - | None -): - adapted = adapt_llm_exception(exc) - - if adapted.category is LLMErrorCategory.RATE_LIMITED: - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", - _provider_error_extra(adapted), - ) - - if adapted.category in { - LLMErrorCategory.AUTH_FAILED, - LLMErrorCategory.PERMISSION_DENIED, - }: - return ( - "model_auth_failed", - "MODEL_AUTH_FAILED", - "warn", - True, - "This model's API key is invalid or expired. Switch models, or update the API key.", - _provider_error_extra(adapted), - ) - - if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND: - return ( - "model_not_found", - "MODEL_NOT_FOUND", - "warn", - True, - "The selected model is unavailable or no longer exists. Switch to another model and try again.", - _provider_error_extra(adapted), - ) - - if adapted.category is LLMErrorCategory.CONTEXT_LIMIT: - return ( - "model_context_limit", - "MODEL_CONTEXT_LIMIT", - "warn", - True, - "This request is too large for the selected model. Try a model with a larger context window or reduce the input.", - _provider_error_extra(adapted), - ) - - if adapted.category in { - LLMErrorCategory.TIMEOUT, - LLMErrorCategory.PROVIDER_UNAVAILABLE, - LLMErrorCategory.BAD_GATEWAY, - LLMErrorCategory.CONNECTION_FAILED, - LLMErrorCategory.SERVER_ERROR, - }: - return ( - "model_provider_unavailable", - "MODEL_PROVIDER_UNAVAILABLE", - "warn", - True, - "The selected model provider is temporarily unavailable. Please try again or switch models.", - _provider_error_extra(adapted), - ) - - return None - - def classify_stream_exception( exc: Exception, *, @@ -250,9 +167,15 @@ def classify_stream_exception( None, ) - provider_classification = _classify_provider_exception(exc) - if provider_classification is not None: - return provider_classification + if is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, + ) return ( "server_error", diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py index 1e6097e53..e33dca376 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py @@ -85,11 +85,11 @@ from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( setup_connector_and_firecrawl, ) from app.tasks.chat.streaming.flows.shared.premium_quota import ( - CreditReservation, - finalize_credit, - needs_credit_quota, - release_credit, - reserve_credit, + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, ) from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( can_recover_provider_rate_limit, @@ -182,7 +182,7 @@ async def stream_new_chat( accumulator = start_turn() - premium_reservation: CreditReservation | None = None + premium_reservation: PremiumReservation | None = None busy_error_raised = False emit_stream_error = partial( @@ -259,8 +259,8 @@ async def stream_new_chat( yield streaming_service.format_done() return - if needs_credit_quota(agent_config, user_id): - premium_reservation = await reserve_credit( + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( agent_config=agent_config, user_id=user_id, # type: ignore[arg-type] ) @@ -336,7 +336,7 @@ async def stream_new_chat( else: yield emit_stream_error( message=( - "Buy more credits to continue with this model, or " + "Buy more tokens to continue with this model, or " "switch to a free model" ), error_kind="premium_quota_exhausted", @@ -762,7 +762,7 @@ async def stream_new_chat( # sub-agent calls during a premium turn still contribute to the bill # (they're $0 in practice anyway). if premium_reservation is not None and user_id: - await finalize_credit( + await finalize_premium( reservation=premium_reservation, user_id=user_id, accumulator=accumulator, @@ -812,7 +812,7 @@ async def stream_new_chat( end_turn(str(chat_id)) if premium_reservation is not None and user_id: - await release_credit(reservation=premium_reservation, user_id=user_id) + await release_premium(reservation=premium_reservation, user_id=user_id) await close_session_and_clear_ai_responding(session, chat_id) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index d5e8c3729..fe3d210bb 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -80,6 +80,7 @@ async def _generate_title( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator # Excludes this turn's own assistant row (pre-written by @@ -124,12 +125,26 @@ async def _generate_title( router = LLMRouterService.get_router() response = await router.acompletion(model="auto", messages=messages) else: + # Apply the same ``api_base`` cascade chat / vision / image-gen + # call sites use so we never inherit ``litellm.api_base`` + # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat + # config itself ships an empty ``api_base``. Without this the + # title-gen on an OpenRouter chat config would 404 against the + # inherited Azure endpoint — see ``provider_api_base`` for the + # same bug repro on the image-gen / vision paths. raw_model = getattr(llm, "model", "") or "" + provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None + provider_value = agent_config.provider if agent_config is not None else None + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py index e1552e79e..6d0924850 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -64,11 +64,11 @@ from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( setup_connector_and_firecrawl, ) from app.tasks.chat.streaming.flows.shared.premium_quota import ( - CreditReservation, - finalize_credit, - needs_credit_quota, - release_credit, - reserve_credit, + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, ) from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( can_recover_provider_rate_limit, @@ -144,7 +144,7 @@ async def stream_resume_chat( accumulator = start_turn() - premium_reservation: CreditReservation | None = None + premium_reservation: PremiumReservation | None = None busy_error_raised = False emit_stream_error = partial( @@ -212,8 +212,8 @@ async def stream_resume_chat( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - if needs_credit_quota(agent_config, user_id): - premium_reservation = await reserve_credit( + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( agent_config=agent_config, user_id=user_id, # type: ignore[arg-type] ) @@ -285,7 +285,7 @@ async def stream_resume_chat( else: yield emit_stream_error( message=( - "Buy more credits to continue with this model, or " + "Buy more tokens to continue with this model, or " "switch to a free model" ), error_kind="premium_quota_exhausted", @@ -544,7 +544,7 @@ async def stream_resume_chat( return if premium_reservation is not None and user_id: - await finalize_credit( + await finalize_premium( reservation=premium_reservation, user_id=user_id, accumulator=accumulator, @@ -584,7 +584,7 @@ async def stream_resume_chat( end_turn(str(chat_id)) if premium_reservation is not None and user_id: - await release_credit(reservation=premium_reservation, user_id=user_id) + await release_premium(reservation=premium_reservation, user_id=user_id) await close_session_and_clear_ai_responding(session, chat_id) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py index 3f767c60b..be1f102f3 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -53,7 +53,6 @@ async def finalize_assistant_message( ): return - from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts from app.tasks.chat.persistence import finalize_assistant_turn builder_stats: dict[str, int] | None = None @@ -75,10 +74,6 @@ async def finalize_assistant_message( "text": stream_result.accumulated_text or "", } ] - content_payload = merge_streamed_and_final_parts( - content_payload, - stream_result.final_message_parts, - ) if builder_stats is not None: _perf_log.info( diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 6f905e8f4..7e2bc950b 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -1,8 +1,8 @@ """Load an LLM + AgentConfig bundle for a given config id. Handles both code paths uniformly: -- ``config_id > 0`` → database-backed model-connection ``Model`` row. -- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter. +- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). +- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is ``None``. The caller emits the friendly SSE error frame. @@ -12,78 +12,15 @@ from __future__ import annotations from typing import Any -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.agents.chat.runtime.llm_config import ( AgentConfig, - SanitizedChatLiteLLM, + create_chat_litellm_from_agent_config, + create_chat_litellm_from_config, + load_agent_config, + load_global_llm_config_by_id, ) -from app.config import config -from app.db import Model, SearchSpace -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm -from app.services.token_tracking_service import register_model_usage_metadata - - -def _agent_config_from_resolved( - *, - config_id: int, - config_name: str | None, - provider: str, - model_name: str, - api_key: str | None, - api_base: str | None, - litellm_params: dict | None, - supports_image_input: bool, - billing_tier: str = "free", -) -> AgentConfig: - return AgentConfig( - provider=provider, - model_name=model_name, - api_key=api_key or "", - api_base=api_base, - custom_provider=None, - litellm_params=litellm_params, - config_id=config_id, - config_name=config_name, - is_auto_mode=False, - billing_tier=billing_tier, - is_premium=billing_tier == "premium", - supports_image_input=supports_image_input, - ) - - -async def _load_search_space( - session: AsyncSession, search_space_id: int -) -> SearchSpace | None: - result = await session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - return result.scalars().first() - - -async def _load_db_model( - session: AsyncSession, - *, - model_id: int, - search_space: SearchSpace, -) -> Model | None: - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id, Model.enabled.is_(True)) - ) - model = result.scalars().first() - if not model or not model.connection or not model.connection.enabled: - return None - conn = model.connection - if conn.search_space_id is not None and conn.search_space_id != search_space.id: - return None - if conn.user_id is not None and conn.user_id != search_space.user_id: - return None - return model async def load_llm_bundle( @@ -92,93 +29,29 @@ async def load_llm_bundle( config_id: int, search_space_id: int, ) -> tuple[Any, AgentConfig | None, str | None]: - search_space = await _load_search_space(session, search_space_id) - if not search_space: - return None, None, f"Search space {search_space_id} not found" - - if config_id > 0: - model = await _load_db_model( - session, - model_id=config_id, - search_space=search_space, + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, ) - if not model or not has_capability(model, "chat"): + if not loaded_agent_config: return ( None, None, - f"Failed to load chat model with id {config_id}", + f"Failed to load NewLLMConfig with id {config_id}", ) - model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) - display_name = model.display_name or model.model_id - provider = model.connection.provider or "" - register_model_usage_metadata( - model=model_string, - model_ref=f"db:{model.id}", - model_id=model.model_id, - display_name=display_name, - provider=provider, - ) - agent_config = _agent_config_from_resolved( - config_id=config_id, - config_name=display_name, - provider=provider, - model_name=model.model_id, - api_key=model.connection.api_key, - api_base=model.connection.base_url, - litellm_params=(model.connection.extra or {}).get("litellm_params"), - supports_image_input=has_capability(model, "vision"), - billing_tier="free", - ) return ( - SanitizedChatLiteLLM( - model=model_string, **{**litellm_kwargs, "streaming": True} - ), - agent_config, + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, None, ) - global_model = next( - (m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None - ) - if not global_model or not has_capability(global_model, "chat"): - return None, None, f"Failed to load global chat model with id {config_id}" - global_connection = next( - ( - c - for c in config.GLOBAL_CONNECTIONS - if c.get("id") == global_model.get("connection_id") - ), - None, - ) - if not global_connection: - return None, None, f"Failed to load global connection for model {config_id}" - model_string, litellm_kwargs = to_litellm( - global_connection, global_model["model_id"] - ) - display_name = global_model.get("display_name") or global_model.get("model_id") - provider = global_connection.get("provider") or "" - register_model_usage_metadata( - model=model_string, - model_ref=f"global:{config_id}", - model_id=global_model["model_id"], - display_name=display_name, - provider=provider, - ) - agent_config = _agent_config_from_resolved( - config_id=config_id, - config_name=display_name, - provider=provider, - model_name=global_model["model_id"], - api_key=global_connection.get("api_key"), - api_base=global_connection.get("base_url"), - litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), - supports_image_input=has_capability(global_model, "vision"), - billing_tier=str(global_model.get("billing_tier", "free")).lower(), - ) + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" return ( - SanitizedChatLiteLLM( - model=model_string, **{**litellm_kwargs, "streaming": True} - ), - agent_config, + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), None, ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py index 232071394..6c08cb29f 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py @@ -1,12 +1,13 @@ -"""Credit wallet (USD micro-units) reserve / finalize / release lifecycle. +"""Premium credit (USD micro-units) reserve / finalize / release lifecycle. -Both ``stream_new_chat`` and ``stream_resume_chat`` reserve credits up front (so -a single LLM call can't run away with the budget), then finalize the actual -provider cost reported by LiteLLM when the turn completes successfully, or -release the reservation on the cancellation / interrupted-without-finalize paths. +Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up +front (so a single LLM call can't run away with the budget), then finalize the +actual provider cost reported by LiteLLM when the turn completes successfully, +or release the reservation on the cancellation / interrupted-without-finalize +paths. -State is held by the orchestrator as a simple ``CreditReservation`` so -reservation, fallback-on-denied, finalize, and release can all be reasoned +State is held by the orchestrator as a simple ``PremiumReservation`` tuple +so reservation, fallback-on-denied, finalize, and release can all be reasoned about from one place. """ @@ -26,8 +27,8 @@ if TYPE_CHECKING: @dataclass -class CreditReservation: - """Active credit reservation for one turn. +class PremiumReservation: + """Active premium-credit reservation for one turn. ``request_id`` is the per-reservation idempotency key (also passed to ``finalize``/``release`` so racing branches resolve to the same row). @@ -40,15 +41,15 @@ class CreditReservation: allowed: bool -def needs_credit_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool: +def needs_premium_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool: return bool(agent_config is not None and user_id and agent_config.is_premium) -async def reserve_credit( +async def reserve_premium( *, agent_config: AgentConfig, user_id: str, -) -> CreditReservation: +) -> PremiumReservation: """Reserve estimated micros up front; returns the reservation handle.""" from app.services.token_quota_service import ( TokenQuotaService, @@ -67,22 +68,22 @@ async def reserve_credit( quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.credit_reserve( + quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=request_id, reserve_micros=reserve_amount_micros, ) - return CreditReservation( + return PremiumReservation( request_id=request_id, reserved_micros=reserve_amount_micros, allowed=quota_result.allowed, ) -async def finalize_credit( +async def finalize_premium( *, - reservation: CreditReservation, + reservation: PremiumReservation, user_id: str, accumulator: TokenAccumulator, ) -> None: @@ -95,7 +96,7 @@ async def finalize_credit( from app.services.token_quota_service import TokenQuotaService async with shielded_async_session() as quota_session: - await TokenQuotaService.credit_finalize( + await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=reservation.request_id, @@ -104,15 +105,15 @@ async def finalize_credit( ) except Exception: logging.getLogger(__name__).warning( - "Failed to finalize credit quota for user %s", + "Failed to finalize premium quota for user %s", user_id, exc_info=True, ) -async def release_credit( +async def release_premium( *, - reservation: CreditReservation, + reservation: PremiumReservation, user_id: str, ) -> None: """Release the reservation on cancellation paths; never raises.""" @@ -120,12 +121,12 @@ async def release_credit( from app.services.token_quota_service import TokenQuotaService async with shielded_async_session() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), reserved_micros=reservation.reserved_micros, ) except Exception: logging.getLogger(__name__).warning( - "Failed to release credit quota for user %s", user_id + "Failed to release premium quota for user %s", user_id ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py index b21357b50..f1a1e9c37 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py @@ -15,32 +15,22 @@ def iter_completion_emission_frames( out = ctx.tool_output payload = out if isinstance(out, dict) else {"result": out} yield ctx.emit_tool_output_card(payload) - status = out.get("status") if isinstance(out, dict) else None - title = out.get("title", "Podcast") if isinstance(out, dict) else "Podcast" - if status in ( - "awaiting_brief", - "awaiting_review", + if isinstance(out, dict) and out.get("status") in ( "pending", - "drafting", - "rendering", + "generating", + "processing", ): - # This line is persisted with the chat while the podcast keeps moving, - # so it must stay true after the lifecycle outgrows today's status. yield ctx.streaming_service.format_terminal_info( - f"Podcast created: {title}", + f"Podcast queued: {out.get('title', 'Podcast')}", "success", ) - elif status in ("ready", "success"): + elif isinstance(out, dict) and out.get("status") in ("ready", "success"): yield ctx.streaming_service.format_terminal_info( - f"Podcast generated successfully: {title}", + f"Podcast generated successfully: {out.get('title', 'Podcast')}", "success", ) - elif status in ("failed", "error"): - error_msg = ( - out.get("error", "Unknown error") - if isinstance(out, dict) - else "Unknown error" - ) + elif isinstance(out, dict) and out.get("status") in ("failed", "error"): + error_msg = out.get("error", "Unknown error") yield ctx.streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py index fe8f9cfb7..5cf78ea72 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py @@ -24,11 +24,11 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking d.get("source_content", "") if isinstance(tool_input, dict) else "" ) return ToolStartThinking( - title="Preparing podcast", + title="Generating podcast", items=[ f"Title: {podcast_title}", f"Content: {content_len:,} characters", - "Proposing brief (language, voices, length)...", + "Preparing audio generation...", ], ) @@ -50,19 +50,17 @@ def resolve_completed_thinking( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status in ( - "awaiting_brief", - "awaiting_review", - "pending", - "drafting", - "rendering", - ): - # Persisted with the chat while the podcast keeps moving, so the copy - # must stay true after the lifecycle outgrows today's status. + if podcast_status in ("pending", "generating", "processing"): completed = [ f"Title: {podcast_title}", - "Podcast created", - "Review and progress continue on the podcast card", + "Podcast generation started", + "Processing in background...", + ] + elif podcast_status == "already_generating": + completed = [ + f"Title: {podcast_title}", + "Podcast already in progress", + "Please wait for it to complete", ] elif podcast_status in ("failed", "error"): error_msg = ( @@ -81,4 +79,4 @@ def resolve_completed_thinking( ] else: completed = items - return ("Preparing podcast", completed) + return ("Generating podcast", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py index 5e164070a..a940e8a9f 100644 --- a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -35,7 +35,3 @@ class StreamResult: # (``StreamResult`` is logged in some error branches) from dumping a # potentially-large parts list. content_builder: Any | None = field(default=None, repr=False) - # User-visible assistant message parts derived from the final LangGraph - # state. Used after streaming completes as a provider-agnostic persistence - # backfill when no text chunks reached the live stream. - final_message_parts: list[dict[str, Any]] = field(default_factory=list) diff --git a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py index 9bf290d85..7cd3e1613 100644 --- a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py @@ -28,7 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.exceptions import safe_exception_message from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -423,8 +423,9 @@ async def _index_full_scan( }, ) - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -466,17 +467,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Dropbox full scan, " + "Page limit reached during Dropbox full scan, " "skipping remaining files" ) page_limit_reached = True @@ -501,7 +498,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -524,8 +523,9 @@ async def _index_selected_files( vision_llm=None, ) -> tuple[int, int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -560,16 +560,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_path - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -590,7 +586,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors diff --git a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py index 557c2ce71..ce9b80e5e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py @@ -525,7 +525,6 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list: Chunk( content=chunk_text, embedding=embed_text(chunk_text), - position=len(chunks), ) ) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 37de66ffd..b76f84bac 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -41,7 +41,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( PlaceholderInfo, ) from app.services.composio_service import ComposioService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -555,11 +555,11 @@ async def _process_single_file( return 1, 0, 0 return 0, 1, 0 - etl_credit_service = EtlCreditService(session) - estimated_pages = EtlCreditService.estimate_pages_from_metadata( + page_limit_service = PageLimitService(session) + estimated_pages = PageLimitService.estimate_pages_from_metadata( file_name, file.get("size") ) - await etl_credit_service.check_credits(user_id, estimated_pages) + await page_limit_service.check_page_limit(user_id, estimated_pages) markdown, drive_metadata, error = await download_and_extract_content( drive_client, file, vision_llm=vision_llm @@ -602,7 +602,9 @@ async def _process_single_file( continue await pipeline.index(document, connector_doc) - await etl_credit_service.charge_credits(user_id, estimated_pages) + await page_limit_service.update_page_usage( + user_id, estimated_pages, allow_exceed=True + ) logger.info(f"Successfully indexed Google Drive file: {file_name}") return 1, 0, 0 @@ -711,8 +713,9 @@ async def _index_selected_files( Returns (indexed_count, skipped_count, unsupported_count, errors). """ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -738,16 +741,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_id - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -776,7 +775,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors @@ -819,8 +820,9 @@ async def _index_full_scan( # ------------------------------------------------------------------ # Phase 1 (serial): collect files, run skip checks, track renames # ------------------------------------------------------------------ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -875,19 +877,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros( - batch_estimated_pages + file_pages - ) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Google Drive full scan, " + "Page limit reached during Google Drive full scan, " "skipping remaining files" ) page_limit_reached = True @@ -942,7 +938,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -998,8 +996,9 @@ async def _index_with_delta_sync( # ------------------------------------------------------------------ # Phase 1 (serial): handle removals, collect files for download # ------------------------------------------------------------------ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -1035,17 +1034,13 @@ async def _index_with_delta_sync( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Google Drive delta sync, " + "Page limit reached during Google Drive delta sync, " "skipping remaining files" ) page_limit_reached = True @@ -1084,7 +1079,9 @@ async def _index_with_delta_sync( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 2505fa7c4..1cd92dcf8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -33,7 +33,7 @@ from app.db import ( from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService, InsufficientCreditsError +from app.services.page_limit_service import PageLimitExceededError, PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.celery_tasks import get_celery_session_maker from app.utils.document_versioning import create_version_snapshot @@ -46,38 +46,38 @@ from .base import ( HeartbeatCallbackType = Callable[[int], Awaitable[None]] -def _estimate_pages_safe(etl_credit_service: EtlCreditService, file_path: str) -> int: +def _estimate_pages_safe(page_limit_service: PageLimitService, file_path: str) -> int: """Estimate page count with a file-size fallback.""" try: - return etl_credit_service.estimate_pages_before_processing(file_path) + return page_limit_service.estimate_pages_before_processing(file_path) except Exception: file_size = os.path.getsize(file_path) return max(1, file_size // (80 * 1024)) -async def _check_credits_or_skip( - etl_credit_service: EtlCreditService, +async def _check_page_limit_or_skip( + page_limit_service: PageLimitService, user_id: str, file_path: str, page_multiplier: int = 1, ) -> tuple[int, int]: - """Estimate pages and check credit; raises InsufficientCreditsError if unaffordable. + """Estimate pages and check the limit; raises PageLimitExceededError if over quota. Returns (estimated_pages, billable_pages). """ - estimated = _estimate_pages_safe(etl_credit_service, file_path) + estimated = _estimate_pages_safe(page_limit_service, file_path) billable = estimated * page_multiplier - await etl_credit_service.check_credits(user_id, billable) + await page_limit_service.check_page_limit(user_id, billable) return estimated, billable def _compute_final_pages( - etl_credit_service: EtlCreditService, + page_limit_service: PageLimitService, estimated_pages: int, content_length: int, ) -> int: """Return the final page count as max(estimated, actual).""" - actual = etl_credit_service.estimate_pages_from_content_length(content_length) + actual = page_limit_service.estimate_pages_from_content_length(content_length) return max(estimated_pages, actual) @@ -162,13 +162,12 @@ async def _read_file_content( All file types (plaintext, audio, direct-convert, document, image) are handled by ``EtlPipelineService``. """ - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService mode = ProcessingMode.coerce(processing_mode) - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename, processing_mode=mode), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename, processing_mode=mode) ) return result.markdown_content @@ -636,7 +635,7 @@ async def index_local_folder( skipped_count = 0 failed_count = 0 - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) # ================================================================ # PHASE 1: Pre-filter files (mtime / content-hash), version changed @@ -695,12 +694,12 @@ async def index_local_folder( continue try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, file_path_abs + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, file_path_abs ) - except InsufficientCreditsError: + except PageLimitExceededError: logger.warning( - f"Insufficient credits, skipping: {file_path_abs}" + f"Page limit exceeded, skipping: {file_path_abs}" ) failed_count += 1 continue @@ -731,12 +730,12 @@ async def index_local_folder( await create_version_snapshot(session, existing_document) else: try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, file_path_abs + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, file_path_abs ) - except InsufficientCreditsError: + except PageLimitExceededError: logger.warning( - f"Insufficient credits, skipping: {file_path_abs}" + f"Page limit exceeded, skipping: {file_path_abs}" ) failed_count += 1 continue @@ -859,9 +858,11 @@ async def index_local_folder( est = mtime_info.get("estimated_pages", 1) content_len = mtime_info.get("content_length", 0) final_pages = _compute_final_pages( - etl_credit_service, est, content_len + page_limit_service, est, content_len + ) + await page_limit_service.update_page_usage( + user_id, final_pages, allow_exceed=True ) - await etl_credit_service.charge_credits(user_id, final_pages) else: failed_count += 1 @@ -1071,13 +1072,13 @@ async def _index_single_file( await session.commit() return 0, 0, None - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, str(full_path) + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, str(full_path) ) - except InsufficientCreditsError as e: - return 0, 1, f"Insufficient credits: {e}" + except PageLimitExceededError as e: + return 0, 1, f"Page limit exceeded: {e}" try: content, content_hash = await _compute_file_content_hash( @@ -1141,9 +1142,11 @@ async def _index_single_file( if indexed: final_pages = _compute_final_pages( - etl_credit_service, estimated_pages, len(content) + page_limit_service, estimated_pages, len(content) + ) + await page_limit_service.update_page_usage( + user_id, final_pages, allow_exceed=True ) - await etl_credit_service.charge_credits(user_id, final_pages) await task_logger.log_task_success( log_entry, f"Single file indexed: {rel_path}", @@ -1296,7 +1299,7 @@ async def index_uploaded_files( await _set_indexing_flag(session, root_folder_id) - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) pipeline = IndexingPipelineService(session) vision_llm_instance = None @@ -1342,14 +1345,14 @@ async def index_uploaded_files( continue try: - estimated_pages, _billable_pages = await _check_credits_or_skip( - etl_credit_service, + estimated_pages, _billable_pages = await _check_page_limit_or_skip( + page_limit_service, user_id, temp_path, page_multiplier=mode.page_multiplier, ) - except InsufficientCreditsError: - logger.warning(f"Insufficient credits, skipping: {relative_path}") + except PageLimitExceededError: + logger.warning(f"Page limit exceeded, skipping: {relative_path}") failed_count += 1 continue @@ -1422,10 +1425,12 @@ async def index_uploaded_files( if DocumentStatus.is_state(db_doc.status, DocumentStatus.READY): indexed_count += 1 final_pages = _compute_final_pages( - etl_credit_service, estimated_pages, len(content) + page_limit_service, estimated_pages, len(content) ) final_billable = final_pages * mode.page_multiplier - await etl_credit_service.charge_credits(user_id, final_billable) + await page_limit_service.update_page_usage( + user_id, final_billable, allow_exceed=True + ) else: failed_count += 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index 1a83551fb..3fd8a79f2 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -28,7 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.exceptions import safe_exception_message from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -318,8 +318,9 @@ async def _index_selected_files( vision_llm=None, ) -> tuple[int, int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -345,16 +346,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_id - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -375,7 +372,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors @@ -414,8 +413,9 @@ async def _index_full_scan( }, ) - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -448,17 +448,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during OneDrive full scan, " + "Page limit reached during OneDrive full scan, " "skipping remaining files" ) page_limit_reached = True @@ -483,7 +479,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -534,8 +532,9 @@ async def _index_with_delta_sync( logger.info(f"Processing {len(changes)} delta changes") - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -572,17 +571,13 @@ async def _index_with_delta_sync( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( change.get("name", ""), change.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during OneDrive delta sync, " + "Page limit reached during OneDrive delta sync, " "skipping remaining files" ) page_limit_reached = True @@ -607,7 +602,9 @@ async def _index_with_delta_sync( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 174ac966d..f6929b87c 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -1,9 +1,8 @@ """ File document processors orchestrating content extraction and indexing. -Delegates content extraction to the cache-aware ``extract_with_cache`` facade -(over ``EtlPipelineService``) and keeps only orchestration concerns -(notifications, logging, page limits, saving). +Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and +keeps only orchestration concerns (notifications, logging, page limits, saving). """ from __future__ import annotations @@ -80,10 +79,10 @@ async def _notify( # --------------------------------------------------------------------------- -def _estimate_pages_safe(etl_credit_service, file_path: str) -> int: +def _estimate_pages_safe(page_limit_service, file_path: str) -> int: """Estimate page count with a file-size fallback.""" try: - return etl_credit_service.estimate_pages_before_processing(file_path) + return page_limit_service.estimate_pages_before_processing(file_path) except Exception: file_size = os.path.getsize(file_path) return max(1, file_size // (80 * 1024)) @@ -117,8 +116,8 @@ async def _log_page_divergence( async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None: """Extract content from a non-document file (plaintext/direct_convert/audio/image) via the unified ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService await _notify(ctx, "parsing", "Processing file") await ctx.task_logger.log_task_progress( @@ -137,9 +136,8 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await extract_with_cache( - EtlRequest(file_path=ctx.file_path, filename=ctx.filename), - vision_llm=vision_llm, + etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=ctx.file_path, filename=ctx.filename) ) with contextlib.suppress(Exception): @@ -185,16 +183,13 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: """Route a document file to the configured ETL service via the unified pipeline.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode - from app.services.etl_credit_service import ( - EtlCreditService, - InsufficientCreditsError, - ) + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService + from app.services.page_limit_service import PageLimitExceededError, PageLimitService mode = ProcessingMode.coerce(ctx.processing_mode) - etl_credit_service = EtlCreditService(ctx.session) - estimated_pages = _estimate_pages_safe(etl_credit_service, ctx.file_path) + page_limit_service = PageLimitService(ctx.session) + estimated_pages = _estimate_pages_safe(page_limit_service, ctx.file_path) billable_pages = estimated_pages * mode.page_multiplier await ctx.task_logger.log_task_progress( @@ -209,16 +204,16 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: ) try: - await etl_credit_service.check_credits(ctx.user_id, billable_pages) - except InsufficientCreditsError as e: + await page_limit_service.check_page_limit(ctx.user_id, billable_pages) + except PageLimitExceededError as e: await ctx.task_logger.log_task_failure( ctx.log_entry, - f"Insufficient credits before processing: {ctx.filename}", + f"Page limit exceeded before processing: {ctx.filename}", str(e), { - "error_type": "InsufficientCredits", - "balance_micros": e.balance_micros, - "required_micros": e.required_micros, + "error_type": "PageLimitExceeded", + "pages_used": e.pages_used, + "pages_limit": e.pages_limit, "estimated_pages": estimated_pages, "billable_pages": billable_pages, "processing_mode": mode.value, @@ -239,14 +234,13 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await extract_with_cache( + etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( EtlRequest( file_path=ctx.file_path, filename=ctx.filename, estimated_pages=estimated_pages, processing_mode=mode, - ), - vision_llm=vision_llm, + ) ) with contextlib.suppress(Exception): @@ -265,7 +259,9 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: ) if result: - await etl_credit_service.charge_credits(ctx.user_id, billable_pages) + await page_limit_service.update_page_usage( + ctx.user_id, billable_pages, allow_exceed=True + ) if ctx.connector: await update_document_from_connector(result, ctx.connector, ctx.session) await ctx.task_logger.log_task_success( @@ -341,11 +337,11 @@ async def process_file_in_background( except Exception as e: await session.rollback() - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - if isinstance(e, InsufficientCreditsError): + if isinstance(e, PageLimitExceededError): error_message = str(e) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): error_message = str(e.detail) else: error_message = f"Failed to process file: {filename}" @@ -384,6 +380,7 @@ async def _extract_file_content( Tuple of (markdown_content, etl_service_name, billable_pages). """ from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService from app.etl_pipeline.file_classifier import ( FileCategory, classify_file as etl_classify, @@ -417,12 +414,12 @@ async def _extract_file_content( ) if category == FileCategory.DOCUMENT: - from app.services.etl_credit_service import EtlCreditService + from app.services.page_limit_service import PageLimitService - etl_credit_service = EtlCreditService(session) - estimated_pages = _estimate_pages_safe(etl_credit_service, file_path) + page_limit_service = PageLimitService(session) + estimated_pages = _estimate_pages_safe(page_limit_service, file_path) billable_pages = estimated_pages * mode.page_multiplier - await etl_credit_service.check_credits(user_id, billable_pages) + await page_limit_service.check_page_limit(user_id, billable_pages) # Vision LLM is provided to the ETL pipeline for any file category # when the operator opts in. Image files run through it directly; @@ -434,16 +431,13 @@ async def _extract_file_content( vision_llm = await get_vision_llm(session, search_space_id) - from app.etl_pipeline.cache import extract_with_cache - - result = await extract_with_cache( + result = await EtlPipelineService(vision_llm=vision_llm).extract( EtlRequest( file_path=file_path, filename=filename, estimated_pages=estimated_pages, processing_mode=mode, - ), - vision_llm=vision_llm, + ) ) with contextlib.suppress(Exception): @@ -530,10 +524,12 @@ async def process_file_in_background_with_document( ) if billable_pages > 0: - from app.services.etl_credit_service import EtlCreditService + from app.services.page_limit_service import PageLimitService - etl_credit_service = EtlCreditService(session) - await etl_credit_service.charge_credits(user_id, billable_pages) + page_limit_service = PageLimitService(session) + await page_limit_service.update_page_usage( + user_id, billable_pages, allow_exceed=True + ) await task_logger.log_task_success( log_entry, @@ -551,11 +547,11 @@ async def process_file_in_background_with_document( except Exception as e: await session.rollback() - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - if isinstance(e, InsufficientCreditsError): + if isinstance(e, PageLimitExceededError): error_message = str(e) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): error_message = str(e.detail) else: error_message = f"Failed to process file: {filename}" diff --git a/surfsense_backend/app/utils/content_utils.py b/surfsense_backend/app/utils/content_utils.py index aae936888..05a4610c7 100644 --- a/surfsense_backend/app/utils/content_utils.py +++ b/surfsense_backend/app/utils/content_utils.py @@ -18,11 +18,6 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.tasks.chat.llm_history_normalizer import ( - assistant_content_to_llm_text, - user_content_to_llm_content, -) - if TYPE_CHECKING: from app.db import ChatVisibility @@ -100,28 +95,17 @@ async def bootstrap_history_from_db( langchain_messages: list[HumanMessage | AIMessage] = [] for msg in db_messages: + text_content = extract_text_content(msg.content) + if not text_content: + continue if msg.role == "user": - user_content = user_content_to_llm_content( - msg.content, - allow_images=False, - ) - if not user_content: - continue if is_shared: author_name = ( msg.author.display_name if msg.author else None ) or "A team member" - if isinstance(user_content, str): - user_content = f"**[{author_name}]:** {user_content}" - elif user_content and user_content[0].get("type") == "text": - user_content[0] = { - **user_content[0], - "text": f"**[{author_name}]:** {user_content[0].get('text', '')}", - } - langchain_messages.append(HumanMessage(content=user_content)) + text_content = f"**[{author_name}]:** {text_content}" + langchain_messages.append(HumanMessage(content=text_content)) elif msg.role == "assistant": - assistant_text = assistant_content_to_llm_text(msg.content) - if assistant_text: - langchain_messages.append(AIMessage(content=assistant_text)) + langchain_messages.append(AIMessage(content=text_content)) return langchain_messages diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index fef51d692..694ae22ac 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -188,10 +188,8 @@ async def create_document_chunks(content: str) -> list[Chunk]: chunk_texts = [c.text for c in config.chunker_instance.chunk(content)] chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) return [ - Chunk(content=text, embedding=emb, position=i) - for i, (text, emb) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=False) - ) + Chunk(content=text, embedding=emb) + for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) ] diff --git a/surfsense_backend/app/zero_publication.py b/surfsense_backend/app/zero_publication.py index b14ee14d1..d2755d0a1 100644 --- a/surfsense_backend/app/zero_publication.py +++ b/surfsense_backend/app/zero_publication.py @@ -38,7 +38,10 @@ DOCUMENT_COLS = [ USER_COLS = [ "id", - "credit_micros_balance", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", ] AUTOMATION_RUN_COLS = [ @@ -52,22 +55,6 @@ AUTOMATION_RUN_COLS = [ "created_at", ] -# Enough to drive the lifecycle UI by push: status, the reviewable brief, and -# its version. The bulky source_content and transcript are deliberately excluded -# and fetched over REST when a gate opens. -PODCAST_COLS = [ - "id", - "title", - "status", - "spec", - "spec_version", - "duration_seconds", - "error", - "search_space_id", - "thread_id", - "created_at", -] - ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = { "notifications": None, "documents": DOCUMENT_COLS, @@ -78,7 +65,6 @@ ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = { "chat_session_state": None, "user": USER_COLS, "automation_runs": AUTOMATION_RUN_COLS, - "podcasts": PODCAST_COLS, } @@ -86,15 +72,18 @@ def _quote_identifier(identifier: str) -> str: return '"' + identifier.replace('"', '""') + '"' -def _table_columns(conn: Connection, table: str) -> set[str]: - rows = conn.execute( - text( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = current_schema() AND table_name = :table" - ), - {"table": table}, - ).fetchall() - return {row[0] for row in rows} +def _column_exists(conn: Connection, table: str, column: str) -> bool: + return ( + conn.execute( + text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_schema = current_schema() " + "AND table_name = :table AND column_name = :column" + ), + {"table": table, "column": column}, + ).fetchone() + is not None + ) def _expected_columns(conn: Connection, table: str) -> list[str] | None: @@ -103,39 +92,17 @@ def _expected_columns(conn: Connection, table: str) -> list[str] | None: return None expected = list(columns) - if table in {"documents", "user", "podcasts"} and "_0_version" in _table_columns( - conn, table - ): + if table in {"documents", "user"} and _column_exists(conn, table, "_0_version"): expected.append("_0_version") return expected -def _format_table_entry(conn: Connection, table: str) -> str | None: - """Render one SET TABLE entry, or ``None`` if the table isn't ready. - - Historical migrations (e.g. 155/156) call ``apply_publication`` while the - schema is still mid-history, before later migrations add columns that the - canonical shape references. A table is only published once it exists AND - every canonical column exists; otherwise it is omitted entirely and a later - reconcile migration (e.g. 159) picks it up once its columns land. Partial - column lists are deliberately avoided: publishing a column early would - block later ``ALTER COLUMN ... TYPE`` migrations on it (Postgres forbids - retyping columns a publication depends on). ``verify_publication`` remains - strict against the unfiltered canonical shape. - """ - - actual = _table_columns(conn, table) - if not actual: - return None - - table_sql = _quote_identifier(table) +def _format_table_entry(conn: Connection, table: str) -> str: columns = _expected_columns(conn, table) + table_sql = _quote_identifier(table) if columns is None: return table_sql - if any(column not in actual for column in columns): - return None - column_sql = ", ".join(_quote_identifier(column) for column in columns) return f"{table_sql} ({column_sql})" @@ -143,8 +110,9 @@ def _format_table_entry(conn: Connection, table: str) -> str | None: def build_set_table_sql(conn: Connection) -> str: """Build the canonical plain SET TABLE statement for Zero's event triggers.""" - entries = [_format_table_entry(conn, table) for table in ZERO_PUBLICATION] - table_list = ", ".join(entry for entry in entries if entry is not None) + table_list = ", ".join( + _format_table_entry(conn, table) for table in ZERO_PUBLICATION + ) return f"ALTER PUBLICATION {_quote_identifier(PUBLICATION_NAME)} SET TABLE {table_list}" diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 6afc7fd15..16d46445c 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.29" +version = "0.0.27" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ @@ -41,6 +41,7 @@ dependencies = [ "elasticsearch>=9.1.1", "faster-whisper>=1.1.0", "celery[redis]>=5.5.3", + "flower>=2.0.1", "redis>=5.2.1", "firecrawl-py>=4.9.0", "boto3>=1.35.0", diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index a6be063eb..a49d4eab2 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -55,6 +55,7 @@ from app.services.openrouter_integration_service import ( # noqa: E402 _OPENROUTER_DYNAMIC_MARKER, OpenRouterIntegrationService, ) +from app.services.provider_api_base import resolve_api_base # noqa: E402 from app.services.provider_capabilities import ( # noqa: E402 derive_supports_image_input, is_known_text_only_chat_model, @@ -153,13 +154,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) cap = derive_supports_image_input( - provider=cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) block = is_known_text_only_chat_model( - provider=cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -178,7 +179,11 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: def _build_chat_model_string(cfg: dict) -> str: if cfg.get("custom_provider"): return f"{cfg['custom_provider']}/{cfg['model_name']}" - prefix = cfg.get("litellm_provider") or "openai" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) return f"{prefix}/{cfg['model_name']}" @@ -190,6 +195,11 @@ def _build_chat_model_string(cfg: dict) -> str: async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 1)[0], + config_api_base=cfg.get("api_base") or None, + ) kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -208,8 +218,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: "max_tokens": 16, "timeout": 60, } - if cfg.get("api_base"): - kwargs["api_base"] = cfg["api_base"] + if api_base: + kwargs["api_base"] = api_base if cfg.get("litellm_params"): # Strip pricing keys — they're tracking-only and confuse some # provider validators (e.g. azure/openai reject unknown kwargs @@ -247,11 +257,20 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = ( async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + if cfg.get("custom_provider"): prefix = cfg["custom_provider"] else: - prefix = cfg.get("litellm_provider") or "openai" + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) base_kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -259,8 +278,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: "size": "1024x1024", "timeout": 120, } - if cfg.get("api_base"): - base_kwargs["api_base"] = cfg["api_base"] + if api_base: + base_kwargs["api_base"] = api_base if cfg.get("api_version"): base_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -330,6 +349,31 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None: report.add(result) +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + async def probe_image_gen_configs(report: Report, *, live: bool) -> None: print( "\n[image generation configs from global_image_generation_configs (YAML-static)]" @@ -355,7 +399,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None: async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: - """Sample chat/vision-capable and image-gen models + """Sample one chat (vision-capable), one vision, one image-gen model from the live OpenRouter catalogue. Doesn't iterate the full pool (would be hundreds of probes); just validates the integration end- to-end on a representative model from each surface.""" @@ -380,6 +424,9 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: for c in config.GLOBAL_LLM_CONFIGS if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] or_image_gen = [ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" ] @@ -399,6 +446,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ("or-chat", _pick_first(or_chat, "anthropic/claude")), ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] image_picks = [ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` @@ -408,11 +460,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ] print( - f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} " + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" ) - for surface, picked in chat_picks + image_picks: + for surface, picked in chat_picks + vision_picks + image_picks: if not picked: report.add( ProbeResult( @@ -453,6 +505,7 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: async def main(args: argparse.Namespace) -> int: print("Loaded global configs:") print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") @@ -473,6 +526,8 @@ async def main(args: argparse.Namespace) -> int: report = Report() if not args.skip_chat: await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) if not args.skip_image_gen: await probe_image_gen_configs(report, live=args.live) if not args.skip_openrouter: @@ -492,6 +547,7 @@ def _parse_args() -> argparse.Namespace: ) parser.set_defaults(live=True) parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") parser.add_argument("--skip-image-gen", action="store_true") parser.add_argument("--skip-openrouter", action="store_true") return parser.parse_args() diff --git a/surfsense_backend/tests/conftest.py b/surfsense_backend/tests/conftest.py index e227ed287..e2b586aa2 100644 --- a/surfsense_backend/tests/conftest.py +++ b/surfsense_backend/tests/conftest.py @@ -13,14 +13,6 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB) # DATABASE_URL in the environment (e.g. from .env or shell profile). os.environ["DATABASE_URL"] = TEST_DATABASE_URL -# Integration tests authenticate over HTTP via email/password, so the -# password-auth routers must be mounted (they are skipped under AUTH_TYPE=GOOGLE). -# setdefault (not load_dotenv, which runs later with override=False) lets a -# developer's .env=GOOGLE be overridden here while still honouring an explicitly -# exported shell AUTH_TYPE. -os.environ.setdefault("AUTH_TYPE", "LOCAL") -os.environ.setdefault("REGISTRATION_ENABLED", "TRUE") - import pytest # noqa: E402 from app.db import DocumentType # noqa: E402 diff --git a/surfsense_backend/tests/e2e/fakes/embeddings.py b/surfsense_backend/tests/e2e/fakes/embeddings.py index 9a01fb84b..ab9e24df9 100644 --- a/surfsense_backend/tests/e2e/fakes/embeddings.py +++ b/surfsense_backend/tests/e2e/fakes/embeddings.py @@ -57,9 +57,9 @@ def install(patches: list[Any]) -> None: # Consumers that did `from app.utils.document_converters import embed_text/texts` ("app.indexing_pipeline.document_embedder.embed_text", fake_embed_text), ("app.indexing_pipeline.document_embedder.embed_texts", fake_embed_texts), - # Index-cache facade binding (the actual call site for indexing.index) + # Pipeline service binding (the actual call site for indexing.index) ( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", fake_embed_texts, ), ] diff --git a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml index 9ea5e1a29..017fa1eb3 100644 --- a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml +++ b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml @@ -19,7 +19,7 @@ # so the resolved auto-pin id is never sent to a real LLM provider. # The values below only need to pass # auto_model_pin_service._is_usable_global_config() -# which requires id / model_name / litellm_provider / api_key all truthy. +# which requires id / model_name / provider / api_key all truthy. # # Why TWO entries (premium + free): # auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits @@ -44,10 +44,9 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - litellm_provider: "openai" + provider: "OPENAI" model_name: "fake-e2e-model-premium" api_key: "fake-e2e-api-key-not-for-production" - api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 @@ -61,10 +60,9 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - litellm_provider: "openai" + provider: "OPENAI" model_name: "fake-e2e-model-free" api_key: "fake-e2e-api-key-not-for-production" - api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py index 6b8aa3cdb..19f8e3d0a 100644 --- a/surfsense_backend/tests/integration/conftest.py +++ b/surfsense_backend/tests/integration/conftest.py @@ -123,24 +123,11 @@ async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpac return space -@pytest.fixture(autouse=True) -def _derivation_caches_disabled(monkeypatch): - """Keep integration tests hermetic regardless of the developer's .env. - - With the embedding cache enabled, a successful index of some markdown makes - every later index of the same markdown a cache hit -- silently bypassing - patched ``embed_texts`` fakes/failure injections in unrelated tests. Cache - tests opt back in explicitly via ``monkeypatch.setattr``. - """ - monkeypatch.setattr(app_config, "ETL_CACHE_ENABLED", False) - monkeypatch.setattr(app_config, "EMBEDDING_CACHE_ENABLED", False) - - @pytest.fixture def patched_embed_texts(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock, ) return mock @@ -150,7 +137,7 @@ def patched_embed_texts(monkeypatch) -> MagicMock: def patched_embed_texts_raises(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=RuntimeError("Embedding unavailable")) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock, ) return mock @@ -160,11 +147,11 @@ def patched_embed_texts_raises(monkeypatch) -> MagicMock: def patched_chunk_text(monkeypatch) -> MagicMock: mock = MagicMock(return_value=["Test chunk content."]) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", mock, ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock, ) return mock diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index bd889360f..13e3ab59c 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -204,34 +204,32 @@ async def _cleanup_documents( # --------------------------------------------------------------------------- -# Credit-wallet helpers (direct DB for setup, API for verification) +# Page-limit helpers (direct DB for setup, API for verification) # --------------------------------------------------------------------------- -async def _get_user_credit(email: str) -> tuple[int, int]: +async def _get_user_page_usage(email: str) -> tuple[int, int]: conn = await asyncpg.connect(_ASYNCPG_URL) try: row = await conn.fetchrow( - "SELECT credit_micros_balance, credit_micros_reserved " - 'FROM "user" WHERE email = $1', + 'SELECT pages_used, pages_limit FROM "user" WHERE email = $1', email, ) assert row is not None, f"User {email!r} not found in database" - return row["credit_micros_balance"], row["credit_micros_reserved"] + return row["pages_used"], row["pages_limit"] finally: await conn.close() -async def _set_user_credit( - email: str, *, balance_micros: int, reserved_micros: int = 0 +async def _set_user_page_limits( + email: str, *, pages_used: int, pages_limit: int ) -> None: conn = await asyncpg.connect(_ASYNCPG_URL) try: await conn.execute( - 'UPDATE "user" SET credit_micros_balance = $1, ' - "credit_micros_reserved = $2 WHERE email = $3", - balance_micros, - reserved_micros, + 'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3', + pages_used, + pages_limit, email, ) finally: @@ -239,39 +237,23 @@ async def _set_user_credit( @pytest.fixture -async def credits(): - """Manipulate the test user's credit wallet (direct DB for setup only). +async def page_limits(): + """Manipulate the test user's page limits (direct DB for setup only). - Force-enables ETL credit billing for the duration of the test (it is off - by default for self-hosted/OSS, which would bypass all gating), and - automatically restores the original balance and billing flag afterwards. - - ``MICROS_PER_PAGE`` is exposed so callers can size balances by page count. + Automatically restores original values after each test. """ - class _Credits: - micros_per_page = app_config.MICROS_PER_PAGE - - async def set(self, *, balance_micros: int, reserved_micros: int = 0) -> None: - await _set_user_credit( - TEST_EMAIL, - balance_micros=balance_micros, - reserved_micros=reserved_micros, + class _PageLimits: + async def set(self, *, pages_used: int, pages_limit: int) -> None: + await _set_user_page_limits( + TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit ) - def pages(self, n: int) -> int: - return n * app_config.MICROS_PER_PAGE - - original_billing = app_config.ETL_CREDIT_BILLING_ENABLED - app_config.ETL_CREDIT_BILLING_ENABLED = True - original = await _get_user_credit(TEST_EMAIL) - try: - yield _Credits() - finally: - app_config.ETL_CREDIT_BILLING_ENABLED = original_billing - await _set_user_credit( - TEST_EMAIL, balance_micros=original[0], reserved_micros=original[1] - ) + original = await _get_user_page_usage(TEST_EMAIL) + yield _PageLimits() + await _set_user_page_limits( + TEST_EMAIL, pages_used=original[0], pages_limit=original[1] + ) # --------------------------------------------------------------------------- @@ -283,11 +265,11 @@ async def credits(): def _mock_external_apis(monkeypatch): """Mock LLM, embedding, and chunking — these are external API boundaries.""" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", MagicMock(return_value=["Test chunk content."]), ) diff --git a/surfsense_backend/tests/integration/document_upload/test_etl_credits.py b/surfsense_backend/tests/integration/document_upload/test_page_limits.py similarity index 67% rename from surfsense_backend/tests/integration/document_upload/test_etl_credits.py rename to surfsense_backend/tests/integration/document_upload/test_page_limits.py index 6a2972598..985fd7128 100644 --- a/surfsense_backend/tests/integration/document_upload/test_etl_credits.py +++ b/surfsense_backend/tests/integration/document_upload/test_page_limits.py @@ -1,14 +1,14 @@ """ -Integration tests for ETL credit enforcement during document upload. +Integration tests for page-limit enforcement during document upload. -These tests manipulate the test user's ``credit_micros_balance`` column -directly in the database (setup only) and then exercise the upload pipeline -to verify that: +These tests manipulate the test user's ``pages_used`` / ``pages_limit`` +columns directly in the database (setup only) and then exercise the upload +pipeline to verify that: - - Uploads are rejected *before* ETL when the wallet can't cover the cost. - - The balance decreases after a successful upload (verified via API). - - An ``insufficient_credits`` notification is created on rejection. - - The balance is not modified when a document fails processing. + - Uploads are rejected *before* ETL when the limit is exhausted. + - ``pages_used`` increases after a successful upload (verified via API). + - A ``page_limit_exceeded`` notification is created on rejection. + - ``pages_used`` is not modified when a document fails processing. All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``) so no additional processing time is introduced. @@ -32,37 +32,36 @@ pytestmark = pytest.mark.integration # --------------------------------------------------------------------------- -# Helper: read credit balance through the public API +# Helper: read pages_used through the public API # --------------------------------------------------------------------------- -async def _get_balance(client: httpx.AsyncClient, headers: dict[str, str]) -> int: - """Fetch the current user's credit_micros_balance via the /users/me API.""" +async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int: + """Fetch the current user's pages_used via the /users/me API.""" resp = await client.get("/users/me", headers=headers) assert resp.status_code == 200, ( f"GET /users/me failed ({resp.status_code}): {resp.text}" ) - return resp.json()["credit_micros_balance"] + return resp.json()["pages_used"] # --------------------------------------------------------------------------- -# Test A: Successful upload decrements the balance +# Test A: Successful upload increments pages_used # --------------------------------------------------------------------------- -class TestBalanceDecrementsOnSuccess: - """After a successful PDF upload the user's balance must shrink.""" +class TestPageUsageIncrementsOnSuccess: + """After a successful PDF upload the user's ``pages_used`` must grow.""" - async def test_balance_decreases_after_pdf_upload( + async def test_pages_used_increases_after_pdf_upload( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=credits.pages(1000)) - before = await _get_balance(client, headers) + await page_limits.set(pages_used=0, pages_limit=1000) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -77,28 +76,30 @@ class TestBalanceDecrementsOnSuccess: for did in doc_ids: assert statuses[did]["status"]["state"] == "ready" - after = await _get_balance(client, headers) - assert after < before, "balance should have dropped after successful processing" + used = await _get_pages_used(client, headers) + assert used > 0, "pages_used should have increased after successful processing" # --------------------------------------------------------------------------- -# Test B: Upload rejected when the wallet is empty +# Test B: Upload rejected when page limit is fully exhausted # --------------------------------------------------------------------------- -class TestUploadRejectedWhenCreditExhausted: - """When the balance is zero the document should reach ``failed`` status - with an insufficient-credit reason.""" +class TestUploadRejectedWhenLimitExhausted: + """ + When ``pages_used == pages_limit`` (zero remaining) the document + should reach ``failed`` status with a page-limit reason. + """ - async def test_pdf_fails_when_no_credit_remaining( + async def test_pdf_fails_when_no_pages_remaining( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=100, pages_limit=100) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -113,19 +114,19 @@ class TestUploadRejectedWhenCreditExhausted: for did in doc_ids: assert statuses[did]["status"]["state"] == "failed" reason = statuses[did]["status"].get("reason", "").lower() - assert "credit" in reason, ( - f"Expected 'credit' in failure reason, got: {reason!r}" + assert "page limit" in reason, ( + f"Expected 'page limit' in failure reason, got: {reason!r}" ) - async def test_balance_unchanged_after_rejection( + async def test_pages_used_unchanged_after_limit_rejection( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=50, pages_limit=50) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -138,30 +139,30 @@ class TestUploadRejectedWhenCreditExhausted: client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0 ) - balance = await _get_balance(client, headers) - assert balance == 0, ( - f"balance should remain 0 after rejected upload, got {balance}" + used = await _get_pages_used(client, headers) + assert used == 50, ( + f"pages_used should remain 50 after rejected upload, got {used}" ) # --------------------------------------------------------------------------- -# Test C: Insufficient-credits notification is created on rejection +# Test C: Page-limit notification is created on rejection # --------------------------------------------------------------------------- -class TestInsufficientCreditsNotification: - """An ``insufficient_credits`` notification must be created when upload - is rejected due to an empty wallet.""" +class TestPageLimitNotification: + """A ``page_limit_exceeded`` notification must be created when upload + is rejected due to the limit.""" - async def test_insufficient_credits_notification_created( + async def test_page_limit_exceeded_notification_created( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=100, pages_limit=100) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -177,18 +178,19 @@ class TestInsufficientCreditsNotification: notifications = await get_notifications( client, headers, - type_filter="insufficient_credits", + type_filter="page_limit_exceeded", search_space_id=search_space_id, ) assert len(notifications) >= 1, ( - "Expected at least one insufficient_credits notification" + "Expected at least one page_limit_exceeded notification" ) latest = notifications[0] assert ( - "credit" in latest["title"].lower() or "credit" in latest["message"].lower() + "page limit" in latest["title"].lower() + or "page limit" in latest["message"].lower() ), ( - f"Notification should mention credit: title={latest['title']!r}, " + f"Notification should mention page limit: title={latest['title']!r}, " f"message={latest['message']!r}" ) @@ -208,9 +210,9 @@ class TestDocumentProcessingNotification: headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=credits.pages(1000)) + await page_limits.set(pages_used=0, pages_limit=1000) resp = await upload_file( client, headers, "sample.txt", search_space_id=search_space_id @@ -240,24 +242,23 @@ class TestDocumentProcessingNotification: # --------------------------------------------------------------------------- -# Test E: balance unchanged when a document fails for non-credit reasons +# Test E: pages_used unchanged when a document fails for non-limit reasons # --------------------------------------------------------------------------- -class TestBalanceUnchangedOnProcessingFailure: - """If a document fails during ETL (e.g. empty/corrupt file) rather than a - credit rejection, the balance should remain unchanged.""" +class TestPagesUnchangedOnProcessingFailure: + """If a document fails during ETL (e.g. empty/corrupt file) rather than + a page-limit rejection, ``pages_used`` should remain unchanged.""" - async def test_balance_stable_on_etl_failure( + async def test_pages_used_stable_on_etl_failure( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - starting = credits.pages(1000) - await credits.set(balance_micros=starting) + await page_limits.set(pages_used=10, pages_limit=1000) resp = await upload_file( client, headers, "empty.pdf", search_space_id=search_space_id @@ -273,32 +274,28 @@ class TestBalanceUnchangedOnProcessingFailure: for did in doc_ids: assert statuses[did]["status"]["state"] == "failed" - balance = await _get_balance(client, headers) - assert balance == starting, ( - f"balance should remain {starting} after ETL failure, got {balance}" - ) + used = await _get_pages_used(client, headers) + assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}" # --------------------------------------------------------------------------- -# Test F: Second upload rejected after first consumes remaining credit +# Test F: Second upload rejected after first consumes remaining quota # --------------------------------------------------------------------------- -class TestSecondUploadExceedsCredit: - """Upload one PDF successfully, consuming the credit, then verify a second - upload is rejected.""" +class TestSecondUploadExceedsLimit: + """Upload one PDF successfully, consuming the quota, then verify a + second upload is rejected.""" - async def test_second_upload_rejected_after_credit_consumed( + async def test_second_upload_rejected_after_quota_consumed( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - # Exactly one page of credit: the first 1-page PDF fits, the second - # is rejected once the wallet hits zero. - await credits.set(balance_micros=credits.pages(1)) + await page_limits.set(pages_used=0, pages_limit=1) resp1 = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -330,6 +327,6 @@ class TestSecondUploadExceedsCredit: for did in second_ids: assert statuses2[did]["status"]["state"] == "failed" reason = statuses2[did]["status"].get("reason", "").lower() - assert "credit" in reason, ( - f"Expected 'credit' in failure reason, got: {reason!r}" + assert "page limit" in reason, ( + f"Expected 'page limit' in failure reason, got: {reason!r}" ) diff --git a/surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py b/surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py similarity index 76% rename from surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py rename to surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py index e1955494d..143c9e252 100644 --- a/surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py +++ b/surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py @@ -1,10 +1,3 @@ -"""Integration tests for Stripe credit-pack purchases. - -Buying credit packs tops up ``user.credit_micros_balance``. Legacy page-pack -buying has been removed; these tests exercise the credit checkout session, -webhook fulfillment (idempotent), and the reconciliation fallback. -""" - from __future__ import annotations from types import SimpleNamespace @@ -26,8 +19,6 @@ pytestmark = pytest.mark.integration _ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://") -_CREDIT_MICROS_PER_UNIT = 1_000_000 - async def _execute(query: str, *args) -> None: conn = await asyncpg.connect(_ASYNCPG_URL) @@ -51,12 +42,10 @@ async def _get_user_id(email: str) -> str: return str(row["id"]) -async def _get_balance(email: str) -> int: - row = await _fetchrow( - 'SELECT credit_micros_balance FROM "user" WHERE email = $1', email - ) +async def _get_pages_limit(email: str) -> int: + row = await _fetchrow('SELECT pages_limit FROM "user" WHERE email = $1', email) assert row is not None, f"User {email!r} not found" - return row["credit_micros_balance"] + return row["pages_limit"] def _extract_access_token(response: httpx.Response) -> str | None: @@ -112,23 +101,10 @@ def headers(auth_token: str) -> dict[str, str]: @pytest.fixture(autouse=True) -async def _cleanup_credit_purchases(): - await _execute("DELETE FROM credit_purchases") +async def _cleanup_page_purchases(): + await _execute("DELETE FROM page_purchases") yield - await _execute("DELETE FROM credit_purchases") - - -def _configure_credit_buying(monkeypatch) -> None: - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", True) - monkeypatch.setattr( - stripe_routes.config, "STRIPE_CREDIT_PRICE_ID", "price_credit_1" - ) - monkeypatch.setattr( - stripe_routes.config, "STRIPE_CREDIT_MICROS_PER_UNIT", _CREDIT_MICROS_PER_UNIT - ) - monkeypatch.setattr( - stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" - ) + await _execute("DELETE FROM page_purchases") class _FakeCreateStripeClient: @@ -176,19 +152,18 @@ class _FakeReconciliationStripeClient: class TestStripeCheckoutSessionCreation: - async def test_credit_status_reflects_backend_toggle( + async def test_get_status_reflects_backend_toggle( self, client, headers, monkeypatch ): - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", False) - disabled = await client.get("/api/v1/stripe/credit-status", headers=headers) - assert disabled.status_code == 200, disabled.text - assert disabled.json()["credit_buying_enabled"] is False - assert "credit_micros_balance" in disabled.json() + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False) + disabled_response = await client.get("/api/v1/stripe/status", headers=headers) + assert disabled_response.status_code == 200, disabled_response.text + assert disabled_response.json() == {"page_buying_enabled": False} - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", True) - enabled = await client.get("/api/v1/stripe/credit-status", headers=headers) - assert enabled.status_code == 200, enabled.text - assert enabled.json()["credit_buying_enabled"] is True + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", True) + enabled_response = await client.get("/api/v1/stripe/status", headers=headers) + assert enabled_response.status_code == 200, enabled_response.text + assert enabled_response.json() == {"page_buying_enabled": True} async def test_create_checkout_session_records_pending_purchase( self, @@ -207,10 +182,14 @@ class TestStripeCheckoutSessionCreation: fake_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: fake_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 2, "search_space_id": search_space_id}, ) @@ -220,7 +199,7 @@ class TestStripeCheckoutSessionCreation: assert fake_client.last_params is not None assert fake_client.last_params["mode"] == "payment" assert fake_client.last_params["line_items"] == [ - {"price": "price_credit_1", "quantity": 2} + {"price": "price_pages_1000", "quantity": 2} ] assert ( fake_client.last_params["success_url"] @@ -231,21 +210,19 @@ class TestStripeCheckoutSessionCreation: fake_client.last_params["cancel_url"] == f"http://localhost:3000/dashboard/{search_space_id}/purchase-cancel" ) - assert fake_client.last_params["metadata"]["purchase_type"] == "credits" purchase = await _fetchrow( """ - SELECT quantity, credit_micros_granted, status, source - FROM credit_purchases + SELECT quantity, pages_granted, status + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, ) assert purchase is not None assert purchase["quantity"] == 2 - assert purchase["credit_micros_granted"] == 2 * _CREDIT_MICROS_PER_UNIT + assert purchase["pages_granted"] == 2000 assert purchase["status"] == "PENDING" - assert purchase["source"] == "checkout" async def test_create_checkout_session_returns_503_when_buying_disabled( self, @@ -254,34 +231,34 @@ class TestStripeCheckoutSessionCreation: search_space_id: int, monkeypatch, ): - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", False) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False) response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 2, "search_space_id": search_space_id}, ) assert response.status_code == 503, response.text assert ( - response.json()["detail"] == "Credit purchases are temporarily unavailable." + response.json()["detail"] == "Page purchases are temporarily unavailable." ) - count = await _fetchrow("SELECT COUNT(*) AS count FROM credit_purchases") - assert count is not None - assert count["count"] == 0 + purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases") + assert purchase_count is not None + assert purchase_count["count"] == 0 class TestStripeWebhookFulfillment: - async def test_webhook_grants_credit_once( + async def test_webhook_grants_pages_once( self, client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=5_000_000) + await page_limits.set(pages_used=0, pages_limit=100) checkout_session = SimpleNamespace( id="cs_test_webhook_123", @@ -293,16 +270,21 @@ class TestStripeWebhookFulfillment: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 3, "search_space_id": search_space_id}, ) assert create_response.status_code == 200, create_response.text - assert await _get_balance(TEST_EMAIL) == 5_000_000 + initial_limit = await _get_pages_limit(TEST_EMAIL) + assert initial_limit == 100 user_id = await _get_user_id(TEST_EMAIL) webhook_checkout_session = SimpleNamespace( @@ -314,8 +296,7 @@ class TestStripeWebhookFulfillment: metadata={ "user_id": user_id, "quantity": "3", - "credit_micros_per_unit": str(_CREDIT_MICROS_PER_UNIT), - "purchase_type": "credits", + "pages_per_unit": "1000", }, ) event = SimpleNamespace( @@ -334,12 +315,13 @@ class TestStripeWebhookFulfillment: ) assert first_response.status_code == 200, first_response.text - assert await _get_balance(TEST_EMAIL) == 5_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + updated_limit = await _get_pages_limit(TEST_EMAIL) + assert updated_limit == 3100 purchase = await _fetchrow( """ SELECT status, amount_total, currency, stripe_payment_intent_id - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, @@ -357,8 +339,7 @@ class TestStripeWebhookFulfillment: ) assert second_response.status_code == 200, second_response.text - # Idempotent: a duplicate webhook does not double-grant. - assert await _get_balance(TEST_EMAIL) == 5_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + assert await _get_pages_limit(TEST_EMAIL) == 3100 class TestStripeReconciliation: @@ -367,10 +348,10 @@ class TestStripeReconciliation: client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=1_000_000) + await page_limits.set(pages_used=220, pages_limit=150) checkout_session = SimpleNamespace( id="cs_test_reconcile_paid_123", @@ -382,15 +363,19 @@ class TestStripeReconciliation: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 3, "search_space_id": search_space_id}, ) assert create_response.status_code == 200, create_response.text - assert await _get_balance(TEST_EMAIL) == 1_000_000 + assert await _get_pages_limit(TEST_EMAIL) == 150 reconciled_session = SimpleNamespace( id=checkout_session.id, @@ -417,15 +402,15 @@ class TestStripeReconciliation: 20, ) - await stripe_reconciliation_task._reconcile_pending_credit_purchases() + await stripe_reconciliation_task._reconcile_pending_page_purchases() assert reconcile_client.requested_ids == [checkout_session.id] - assert await _get_balance(TEST_EMAIL) == 1_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + assert await _get_pages_limit(TEST_EMAIL) == 3220 purchase = await _fetchrow( """ SELECT status, amount_total, currency, stripe_payment_intent_id - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, @@ -441,10 +426,10 @@ class TestStripeReconciliation: client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=500_000) + await page_limits.set(pages_used=0, pages_limit=500) checkout_session = SimpleNamespace( id="cs_test_reconcile_expired_123", @@ -456,10 +441,14 @@ class TestStripeReconciliation: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 1, "search_space_id": search_space_id}, ) @@ -490,14 +479,14 @@ class TestStripeReconciliation: 20, ) - await stripe_reconciliation_task._reconcile_pending_credit_purchases() + await stripe_reconciliation_task._reconcile_pending_page_purchases() - assert await _get_balance(TEST_EMAIL) == 500_000 + assert await _get_pages_limit(TEST_EMAIL) == 500 purchase = await _fetchrow( """ SELECT status - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py deleted file mode 100644 index 4369cc64d..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Real-infra fixtures for the parse-cache integration tests. - -``cache_local_storage`` points the cache's blob store at a throwaway directory so -tests exercise the real ``LocalFileBackend`` (no cloud, no mocks). ``clean_cache_table`` -removes rows written through the facade's own committing session, which the -savepoint-rolled-back ``db_session`` cannot undo. -""" - -from __future__ import annotations - -import pytest -import pytest_asyncio -from sqlalchemy import text - - -@pytest.fixture -def cache_local_storage(tmp_path, monkeypatch): - from app.config import config - from app.etl_pipeline.cache.storage.backend import resolve_cache_backend - - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) - resolve_cache_backend.cache_clear() - yield tmp_path - resolve_cache_backend.cache_clear() - - -@pytest_asyncio.fixture -async def clean_cache_table(async_engine): - yield - async with async_engine.begin() as conn: - await conn.execute(text("DELETE FROM etl_cache_parses")) diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py deleted file mode 100644 index f9acd02d5..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py +++ /dev/null @@ -1,82 +0,0 @@ -"""extract_with_cache end-to-end: real DB + real local storage. - -The only seam mocked is the parser itself (``EtlPipelineService.extract``) -- the -external boundary the facade wraps. Everything else (eligibility, hashing, recall, -remember, blob I/O) runs for real, so these tests prove the actual cost saving: -identical bytes are parsed once and reused. -""" - -from __future__ import annotations - -import pytest - -from app.config import config -from app.etl_pipeline.cache.cached_extraction import extract_with_cache -from app.etl_pipeline.etl_document import EtlRequest, EtlResult, ProcessingMode - -pytestmark = pytest.mark.integration - - -class _CountingParser: - """Stand-in for the external parser; records how often it actually ran.""" - - def __init__(self, **_kwargs) -> None: - pass - - calls = 0 - - async def extract(self, request: EtlRequest) -> EtlResult: - type(self).calls += 1 - return EtlResult( - markdown_content="# Parsed once\n", - etl_service="LLAMACLOUD", - actual_pages=3, - content_type="application/pdf", - ) - - -@pytest.fixture -def counting_parser(monkeypatch): - _CountingParser.calls = 0 - monkeypatch.setattr( - "app.etl_pipeline.cache.cached_extraction.EtlPipelineService", - _CountingParser, - ) - return _CountingParser - - -async def test_identical_uploads_are_parsed_once_then_served_from_cache( - tmp_path, monkeypatch, counting_parser, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") - - pdf = tmp_path / "doc.pdf" - pdf.write_bytes(b"%PDF-1.4 unique-bytes-for-this-test") - request = EtlRequest( - file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC - ) - - first = await extract_with_cache(request) - second = await extract_with_cache(request) - - assert counting_parser.calls == 1 # second upload reused the cache - assert first.markdown_content == second.markdown_content == "# Parsed once\n" - assert second.actual_pages == 3 - assert second.content_type == "application/pdf" - - -async def test_disabled_cache_parses_every_time(tmp_path, monkeypatch, counting_parser): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", False) - monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") - - pdf = tmp_path / "doc.pdf" - pdf.write_bytes(b"%PDF-1.4 another-unique-payload") - request = EtlRequest( - file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC - ) - - await extract_with_cache(request) - await extract_with_cache(request) - - assert counting_parser.calls == 2 # bypassed: no reuse diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py deleted file mode 100644 index 4665c44c8..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py +++ /dev/null @@ -1,94 +0,0 @@ -"""CachedParseRepository against real Postgres: the SQL behind eviction & dedup. - -These verify the parts that only a real database can: coldest-first ordering by -reuse then recency, TTL cutoff selection, the size accumulator, and the -insert-once guarantee under a duplicate key. -""" - -from __future__ import annotations - -from datetime import UTC, datetime, timedelta - -import pytest - -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import ParseKey - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -async def _insert(repo, *, sha, size=100, storage_key=None): - key = _key(sha) - await repo.insert( - key=key, - content_type="application/pdf", - actual_pages=1, - storage_backend="local", - storage_key=storage_key or f"etl_cache/{sha}.md", - size_bytes=size, - ) - return key - - -async def test_total_size_bytes_sums_all_rows(db_session): - repo = CachedParseRepository(db_session) - await _insert(repo, sha="a" * 64, size=100) - await _insert(repo, sha="b" * 64, size=250) - - assert await repo.total_size_bytes() == 350 - - -async def test_select_coldest_orders_by_reuse_then_recency(db_session): - repo = CachedParseRepository(db_session) - ka = await _insert(repo, sha="a" * 64) - kb = await _insert(repo, sha="b" * 64) - kc = await _insert(repo, sha="c" * 64) - - # Warm B once and C twice; A stays untouched and should be coldest. - await repo.mark_used((await repo.get(kb)).id) - await repo.mark_used((await repo.get(kc)).id) - await repo.mark_used((await repo.get(kc)).id) - - coldest = await repo.select_coldest(limit=10) - - ids_by_reuse = [c.id for c in coldest] - assert ids_by_reuse[:3] == [ - (await repo.get(ka)).id, - (await repo.get(kb)).id, - (await repo.get(kc)).id, - ] - - -async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): - repo = CachedParseRepository(db_session) - await _insert(repo, sha="a" * 64) - - future = datetime.now(UTC) + timedelta(days=1) - past = datetime.now(UTC) - timedelta(days=1) - - # Row was just used, so it's older than a future cutoff but not a past one. - assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 - assert await repo.select_expired(cutoff=past, limit=10) == [] - - -async def test_duplicate_key_insert_keeps_the_first_row(db_session): - repo = CachedParseRepository(db_session) - key = await _insert(repo, sha="a" * 64, size=100, storage_key="etl_cache/first.md") - - # Same content-addressed key (a concurrent re-parse): must be a no-op. - await repo.insert( - key=key, - content_type="application/pdf", - actual_pages=1, - storage_backend="local", - storage_key="etl_cache/second.md", - size_bytes=999, - ) - - row = await repo.get(key) - assert row.storage_key == "etl_cache/first.md" - assert await repo.total_size_bytes() == 100 diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py deleted file mode 100644 index e6041d63e..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py +++ /dev/null @@ -1,65 +0,0 @@ -"""EtlCacheService end-to-end against real Postgres + real local storage. - -Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: -a miss returns nothing, and a remembered parse comes back as an equivalent -``EtlResult`` rebuilt from the row and the blob. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.etl_document import EtlResult - -pytestmark = pytest.mark.integration - - -def _key(sha: str = "c" * 64) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): - service = EtlCacheService(db_session) - assert await service.recall(_key()) is None - - -async def test_remembered_parse_recalls_as_equivalent_result( - db_session, cache_local_storage -): - service = EtlCacheService(db_session) - stored = EtlResult( - markdown_content="# Cached doc\n\nBody paragraph.\n", - etl_service="LLAMACLOUD", - actual_pages=7, - content_type="application/pdf", - ) - - await service.remember(_key(), stored) - recalled = await service.recall(_key()) - - assert recalled is not None - assert recalled.markdown_content == stored.markdown_content - assert recalled.etl_service == "LLAMACLOUD" - assert recalled.actual_pages == 7 - assert recalled.content_type == "application/pdf" - - -async def test_repeated_recall_keeps_serving_the_same_content( - db_session, cache_local_storage -): - service = EtlCacheService(db_session) - stored = EtlResult( - markdown_content="# Stable\n", - etl_service="LLAMACLOUD", - actual_pages=1, - content_type="application/pdf", - ) - await service.remember(_key(), stored) - - first = await service.recall(_key()) - second = await service.recall(_key()) - - assert first is not None and second is not None - assert first.markdown_content == second.markdown_content == "# Stable\n" diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py deleted file mode 100644 index 939ac74a5..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py +++ /dev/null @@ -1,96 +0,0 @@ -"""The eviction task on real infra: TTL expiry first, then coldest-over-budget. - -Seeds entries through the real cache (DB rows + local blobs), runs the actual -``_evict`` coroutine, and checks what survives via ``recall`` -- no mocks. TTL and -budget are driven through config so each phase can be exercised in isolation. -""" - -from __future__ import annotations - -import pytest - -from app.config import config -from app.etl_pipeline.cache.eviction.task import _evict -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.etl_document import EtlResult -from app.tasks.celery_tasks import get_celery_session_maker - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -def _result(markdown: str) -> EtlResult: - return EtlResult( - markdown_content=markdown, - etl_service="LLAMACLOUD", - actual_pages=1, - content_type="application/pdf", - ) - - -async def _remember(key: ParseKey, result: EtlResult) -> None: - async with get_celery_session_maker()() as session: - await EtlCacheService(session).remember(key, result) - - -async def _recall(key: ParseKey) -> EtlResult | None: - async with get_celery_session_maker()() as session: - return await EtlCacheService(session).recall(key) - - -async def test_expired_entries_are_pruned( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr( - config, "ETL_CACHE_TTL_DAYS", -1 - ) # cutoff in the future -> stale - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) # size phase no-op - - key = _key("a" * 64) - await _remember(key, _result("# stale doc\n")) - - await _evict() - - assert await _recall(key) is None - - -async def test_coldest_entries_are_shed_when_over_budget( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) # nothing TTL-expired - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 1) # ~1 MiB budget - - cold = _key("a" * 64) - warm = _key("b" * 64) - # Two ~0.6 MiB entries together exceed the 1 MiB budget; one must go. - await _remember(cold, _result("x" * 600_000)) - await _remember(warm, _result("y" * 600_000)) - - # A reuse makes `warm` warmer than `cold`, so `cold` is the eviction target. - assert await _recall(warm) is not None - - await _evict() - - assert await _recall(cold) is None - assert await _recall(warm) is not None - - -async def test_nothing_is_evicted_within_ttl_and_budget( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) - - key = _key("a" * 64) - await _remember(key, _result("# keep me\n")) - - await _evict() - - assert await _recall(key) is not None diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py deleted file mode 100644 index a9d685017..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py +++ /dev/null @@ -1,42 +0,0 @@ -"""MarkdownCacheStore against a real local filesystem backend (no mocks). - -Proves the blob side of the cache: markdown written under a content-addressed key -comes back byte-for-byte, and a delete actually removes it. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key - -pytestmark = pytest.mark.integration - - -def _key() -> ParseKey: - return ParseKey.for_document( - "d" * 64, etl_service="LLAMACLOUD", mode="basic", version=1 - ) - - -async def test_save_then_load_round_trips_markdown(cache_local_storage): - store = MarkdownCacheStore() - markdown = "# Title\n\nBody with unicode: café, naïve, 漢字.\n" - - storage_key = await store.save(_key(), markdown) - - assert storage_key == build_parse_object_key(_key()) - assert await store.load(storage_key) == markdown - - -async def test_delete_removes_the_blob(cache_local_storage): - store = MarkdownCacheStore() - storage_key = await store.save(_key(), "to be deleted") - - await store.delete(storage_key) - - # Eviction deleted the blob; a later read must fail rather than serve stale. - with pytest.raises(FileNotFoundError): - await store.load(storage_key) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py index 814129c8d..311716052 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py @@ -177,7 +177,7 @@ async def test_reindex_sets_status_ready(db_session, db_search_space, db_user, m async def test_reindex_replaces_chunks(db_session, db_search_space, db_user, mocker): """Reindexing replaces old chunks with new content rather than appending.""" mocker.patch( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", side_effect=[["Original chunk."], ["Updated chunk."]], ) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py deleted file mode 100644 index 6acb457ee..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Real-infra fixtures for the embedding-cache integration tests. - -``cache_local_storage`` points the shared cache backend at a throwaway directory -so tests exercise the real ``LocalFileBackend`` (no cloud, no mocks); the -embedding cache reuses the ETL cache backend, hence the ``ETL_CACHE_STORAGE_*`` -knobs. ``clean_embedding_cache_table`` removes rows written through the store's -own committing session, which the savepoint-rolled-back ``db_session`` cannot undo. -""" - -from __future__ import annotations - -import pytest -import pytest_asyncio -from sqlalchemy import text - - -@pytest.fixture -def cache_local_storage(tmp_path, monkeypatch): - from app.config import config - from app.etl_pipeline.cache.storage.backend import resolve_cache_backend - - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) - resolve_cache_backend.cache_clear() - yield tmp_path - resolve_cache_backend.cache_clear() - - -@pytest_asyncio.fixture -async def clean_embedding_cache_table(async_engine): - yield - async with async_engine.begin() as conn: - await conn.execute(text("DELETE FROM embedding_cache_sets")) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py deleted file mode 100644 index 446932793..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py +++ /dev/null @@ -1,110 +0,0 @@ -"""CachedEmbeddingSetRepository against real Postgres: the SQL behind eviction & dedup. - -These verify the parts only a real database can: the size accumulator, -coldest-first ordering by reuse then recency, TTL cutoff selection, the -insert-once guarantee under a duplicate key, and the reuse counter. -""" - -from __future__ import annotations - -from datetime import UTC, datetime, timedelta - -import pytest - -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256=sha, - embedding_model="test-model", - embedding_dim=4, - chunker_kind="hybrid", - chunker_version=1, - ) - - -async def _insert(repo, *, sha, size=100, storage_key=None, chunk_count=1): - key = _key(sha) - await repo.insert( - key=key, - storage_backend="local", - storage_key=storage_key or f"embedding_cache/{sha}.emb", - size_bytes=size, - chunk_count=chunk_count, - ) - return key - - -async def test_total_size_bytes_sums_all_rows(db_session): - repo = CachedEmbeddingSetRepository(db_session) - await _insert(repo, sha="a" * 64, size=100) - await _insert(repo, sha="b" * 64, size=250) - - assert await repo.total_size_bytes() == 350 - - -async def test_select_coldest_orders_by_reuse_then_recency(db_session): - repo = CachedEmbeddingSetRepository(db_session) - ka = await _insert(repo, sha="a" * 64) - kb = await _insert(repo, sha="b" * 64) - kc = await _insert(repo, sha="c" * 64) - - # Warm B once and C twice; A stays untouched and should be coldest. - await repo.mark_used((await repo.get(kb)).id) - await repo.mark_used((await repo.get(kc)).id) - await repo.mark_used((await repo.get(kc)).id) - - coldest = await repo.select_coldest(limit=10) - - assert [c.id for c in coldest][:3] == [ - (await repo.get(ka)).id, - (await repo.get(kb)).id, - (await repo.get(kc)).id, - ] - - -async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): - repo = CachedEmbeddingSetRepository(db_session) - await _insert(repo, sha="a" * 64) - - future = datetime.now(UTC) + timedelta(days=1) - past = datetime.now(UTC) - timedelta(days=1) - - # Row was just used, so it predates a future cutoff but not a past one. - assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 - assert await repo.select_expired(cutoff=past, limit=10) == [] - - -async def test_duplicate_key_insert_keeps_the_first_row(db_session): - repo = CachedEmbeddingSetRepository(db_session) - key = await _insert( - repo, sha="a" * 64, size=100, storage_key="embedding_cache/first.emb" - ) - - # Same content-addressed key (a concurrent re-embed): must be a no-op. - await repo.insert( - key=key, - storage_backend="local", - storage_key="embedding_cache/second.emb", - size_bytes=999, - chunk_count=42, - ) - - row = await repo.get(key) - assert row.storage_key == "embedding_cache/first.emb" - assert await repo.total_size_bytes() == 100 - - -async def test_mark_used_increments_reuse_count(db_session): - repo = CachedEmbeddingSetRepository(db_session) - key = await _insert(repo, sha="a" * 64) - assert (await repo.get(key)).times_reused == 0 - - await repo.mark_used((await repo.get(key)).id) - await repo.mark_used((await repo.get(key)).id) - - assert (await repo.get(key)).times_reused == 2 diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py deleted file mode 100644 index 548208131..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py +++ /dev/null @@ -1,74 +0,0 @@ -"""EmbeddingCacheService end-to-end against real Postgres + real local storage. - -Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: -a miss returns nothing, a remembered set comes back as equivalent vectors, and a -dimension mismatch is refused rather than served. -""" - -from __future__ import annotations - -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.service import EmbeddingCacheService - -pytestmark = pytest.mark.integration - - -def _key(sha: str = "c" * 64, *, dim: int = 4) -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256=sha, - embedding_model="test-model", - embedding_dim=dim, - chunker_kind="hybrid", - chunker_version=1, - ) - - -async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): - service = EmbeddingCacheService(db_session) - assert await service.recall(_key()) is None - - -async def test_remembered_set_recalls_as_equivalent_vectors( - db_session, cache_local_storage, clean_embedding_cache_table -): - service = EmbeddingCacheService(db_session) - stored = EmbeddingSet( - summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), - chunks=[ - CachedChunk( - "first chunk", np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) - ), - CachedChunk( - "second chunk", np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32) - ), - ], - ) - - await service.remember(_key(), stored) - recalled = await service.recall(_key()) - - assert recalled is not None - assert np.array_equal(recalled.summary_embedding, stored.summary_embedding) - assert [c.text for c in recalled.chunks] == ["first chunk", "second chunk"] - assert np.array_equal(recalled.chunks[0].embedding, stored.chunks[0].embedding) - assert np.array_equal(recalled.chunks[1].embedding, stored.chunks[1].embedding) - - -async def test_recall_refuses_a_set_whose_dimension_changed( - db_session, cache_local_storage, clean_embedding_cache_table -): - # A model kept its name but changed its output width: never serve the stale blob. - service = EmbeddingCacheService(db_session) - stored = EmbeddingSet( - summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), - chunks=[CachedChunk("c", np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))], - ) - await service.remember(_key(dim=4), stored) - - # Same identity (model + chunker + markdown), but the caller now expects dim 8. - recalled = await service.recall(_key(dim=8)) - - assert recalled is None diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py deleted file mode 100644 index 83becd7b5..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py +++ /dev/null @@ -1,63 +0,0 @@ -"""EmbeddingCacheStore against a real local filesystem backend (no mocks). - -Proves the blob side of the cache: an embedding set written under a -content-addressed key comes back with identical vectors, and a delete actually -removes it. -""" - -from __future__ import annotations - -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.storage import EmbeddingCacheStore -from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key - -pytestmark = pytest.mark.integration - - -def _key() -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256="d" * 64, - embedding_model="test-model", - embedding_dim=4, - chunker_kind="hybrid", - chunker_version=1, - ) - - -def _set() -> EmbeddingSet: - return EmbeddingSet( - summary_embedding=np.array([0.5, 0.25, 0.125, 0.0625], dtype=np.float32), - chunks=[ - CachedChunk("café, naïve, 漢字", np.array([1, 2, 3, 4], dtype=np.float32)), - CachedChunk("second", np.array([5, 6, 7, 8], dtype=np.float32)), - ], - ) - - -async def test_save_then_load_round_trips_the_embedding_set(cache_local_storage): - store = EmbeddingCacheStore() - embedding_set = _set() - - storage_key, size_bytes = await store.save(_key(), embedding_set) - loaded = await store.load(storage_key) - - assert storage_key == build_embedding_object_key(_key()) - assert size_bytes > 0 - assert np.array_equal(loaded.summary_embedding, embedding_set.summary_embedding) - assert [c.text for c in loaded.chunks] == ["café, naïve, 漢字", "second"] - assert np.array_equal(loaded.chunks[0].embedding, embedding_set.chunks[0].embedding) - assert np.array_equal(loaded.chunks[1].embedding, embedding_set.chunks[1].embedding) - - -async def test_delete_removes_the_blob(cache_local_storage): - store = EmbeddingCacheStore() - storage_key, _ = await store.save(_key(), _set()) - - await store.delete(storage_key) - - # Eviction deleted the blob; a later read must fail rather than serve stale. - with pytest.raises(FileNotFoundError): - await store.load(storage_key) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py deleted file mode 100644 index 68d5ec0af..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Edit path: re-indexing a document diffs chunks instead of replacing them. - -Unchanged paragraphs must keep their chunk rows (ids survive -> embeddings and -HNSW entries untouched), only new text is embedded, removed text is deleted, -and (position) keeps presentation order correct throughout. -""" - -import pytest -from sqlalchemy import select - -from app.db import Chunk, DocumentStatus -from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService - -pytestmark = pytest.mark.integration - -_V1 = "Intro paragraph.\n\nBody paragraph.\n\nOutro paragraph." - - -@pytest.fixture -def paragraph_chunker(monkeypatch): - """One chunk per markdown paragraph, so edits map to chunk-level diffs.""" - - def _split(markdown, **_kwargs): - return [p for p in markdown.split("\n\n") if p.strip()] - - monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", _split - ) - monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", _split - ) - - -async def _index(service, connector_doc): - prepared = await service.prepare_for_indexing([connector_doc]) - document = prepared[0] - await service.index(document, connector_doc) - return document - - -async def _load_chunks(db_session, document_id): - result = await db_session.execute( - select(Chunk) - .where(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) - ) - return result.scalars().all() - - -@pytest.mark.usefixtures("paragraph_chunker") -async def test_edit_keeps_unchanged_rows_and_embeds_only_the_new_text( - db_session, - db_search_space, - make_connector_document, - patched_embed_texts, -): - service = IndexingPipelineService(session=db_session) - doc_v1 = make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ) - document = await _index(service, doc_v1) - - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - patched_embed_texts.reset_mock() - - edited = "Intro paragraph.\n\nBody paragraph EDITED.\n\nOutro paragraph." - doc_v2 = make_connector_document( - search_space_id=db_search_space.id, source_markdown=edited - ) - await _index(service, doc_v2) - - chunks = await _load_chunks(db_session, document.id) - by_content = {c.content: c for c in chunks} - - # Untouched paragraphs keep their rows (same ids => embeddings reused, - # no HNSW/GIN churn); the edited paragraph got a fresh row. - assert by_content["Intro paragraph."].id == ids_v1["Intro paragraph."] - assert by_content["Outro paragraph."].id == ids_v1["Outro paragraph."] - assert "Body paragraph." not in by_content - assert by_content["Body paragraph EDITED."].id not in ids_v1.values() - - # Exactly one embed call: the document summary plus only the edited text. - (embedded_texts,) = patched_embed_texts.call_args.args - assert embedded_texts == [edited, "Body paragraph EDITED."] - - assert [c.position for c in chunks] == [0, 1, 2] - assert [c.content for c in chunks] == [ - "Intro paragraph.", - "Body paragraph EDITED.", - "Outro paragraph.", - ] - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_head_insert_shifts_positions_without_new_rows_for_old_text( - db_session, - db_search_space, - make_connector_document, -): - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown="Brand new opener.\n\n" + _V1, - ), - ) - - chunks = await _load_chunks(db_session, document.id) - assert [c.content for c in chunks] == [ - "Brand new opener.", - "Intro paragraph.", - "Body paragraph.", - "Outro paragraph.", - ] - assert [c.position for c in chunks] == [0, 1, 2, 3] - # The three original rows survived the shift. - surviving = {c.content: c.id for c in chunks if c.content in ids_v1} - assert surviving == ids_v1 - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_removed_paragraph_is_deleted_and_order_compacts( - db_session, - db_search_space, - make_connector_document, -): - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown="Intro paragraph.\n\nOutro paragraph.", - ), - ) - - chunks = await _load_chunks(db_session, document.id) - assert [(c.content, c.position) for c in chunks] == [ - ("Intro paragraph.", 0), - ("Outro paragraph.", 1), - ] - assert chunks[0].id == ids_v1["Intro paragraph."] - assert chunks[1].id == ids_v1["Outro paragraph."] - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_kill_switch_falls_back_to_full_replace( - db_session, - db_search_space, - make_connector_document, - monkeypatch, -): - from app.config import config - - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.id for c in await _load_chunks(db_session, document.id)} - - monkeypatch.setattr(config, "CHUNK_RECONCILE_ENABLED", False) - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown=_V1 + "\n\nAppended paragraph.", - ), - ) - - chunks = await _load_chunks(db_session, document.id) - # Legacy behavior: every row is recreated, even unchanged paragraphs. - assert {c.id for c in chunks}.isdisjoint(ids_v1) - assert [c.position for c in chunks] == [0, 1, 2, 3] - assert DocumentStatus.is_state(document.status, DocumentStatus.READY) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py index e37c34388..2cd378343 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py @@ -961,37 +961,24 @@ class TestDirectConvert: # ==================================================================== -# Tier 8: ETL Credits (CR1-CR6) +# Tier 8: Page Limits (PL1-PL6) # ==================================================================== -class TestEtlCredits: - @pytest.fixture(autouse=True) - def _enable_etl_billing(self, monkeypatch): - """Force ETL credit billing on (off by default for self-hosted/OSS).""" - from app.config import config - - monkeypatch.setattr(config, "ETL_CREDIT_BILLING_ENABLED", True) - - @staticmethod - def _micros(pages: int) -> int: - from app.config import config - - return pages * config.MICROS_PER_PAGE - +class TestPageLimits: @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr1_full_scan_debits_balance( + async def test_pl1_full_scan_increments_pages_used( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR1: Successful full-scan sync debits user.credit_micros_balance.""" + """PL1: Successful full-scan sync increments user.pages_used.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - starting = self._micros(500) - db_user.credit_micros_balance = starting + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1008,22 +995,21 @@ class TestEtlCredits: assert count == 1 await db_session.refresh(db_user) - assert db_user.credit_micros_balance < starting, ( - "balance should drop after indexing" - ) + assert db_user.pages_used > 0, "pages_used should increase after indexing" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr2_full_scan_blocked_when_credit_exhausted( + async def test_pl2_full_scan_blocked_when_limit_exhausted( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR2: Full-scan skips file when the wallet is empty.""" + """PL2: Full-scan skips file when page limit is exhausted.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = 0 + db_user.pages_used = 100 + db_user.pages_limit = 100 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1039,23 +1025,21 @@ class TestEtlCredits: assert count == 0 await db_session.refresh(db_user) - assert db_user.credit_micros_balance == 0, ( - "balance should not change on rejection" - ) + assert db_user.pages_used == 100, "pages_used should not change on rejection" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr3_single_file_debits_balance( + async def test_pl3_single_file_increments_pages_used( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR3: Single-file mode debits balance on success.""" + """PL3: Single-file mode increments user.pages_used on success.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - starting = self._micros(500) - db_user.credit_micros_balance = starting + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1073,22 +1057,21 @@ class TestEtlCredits: assert count == 1 await db_session.refresh(db_user) - assert db_user.credit_micros_balance < starting, ( - "balance should drop after indexing" - ) + assert db_user.pages_used > 0, "pages_used should increase after indexing" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr4_single_file_blocked_when_credit_exhausted( + async def test_pl4_single_file_blocked_when_limit_exhausted( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR4: Single-file mode skips file when the wallet is empty.""" + """PL4: Single-file mode skips file when page limit is exhausted.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = 0 + db_user.pages_used = 100 + db_user.pages_limit = 100 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1104,25 +1087,24 @@ class TestEtlCredits: assert count == 0 assert err is not None - assert "credit" in err.lower() + assert "page limit" in err.lower() await db_session.refresh(db_user) - assert db_user.credit_micros_balance == 0, ( - "balance should not change on rejection" - ) + assert db_user.pages_used == 100, "pages_used should not change on rejection" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr5_unchanged_resync_no_extra_debit( + async def test_pl5_unchanged_resync_no_extra_pages( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR5: Re-syncing an unchanged file does not consume additional credit.""" + """PL5: Re-syncing an unchanged file does not consume additional pages.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = self._micros(500) + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello\n\nSame content.") @@ -1137,8 +1119,8 @@ class TestEtlCredits: assert count1 == 1 await db_session.refresh(db_user) - balance_after_first = db_user.credit_micros_balance - assert balance_after_first < self._micros(500) + pages_after_first = db_user.pages_used + assert pages_after_first > 0 count2, _, _, _ = await index_local_folder( session=db_session, @@ -1151,12 +1133,12 @@ class TestEtlCredits: assert count2 == 0 await db_session.refresh(db_user) - assert db_user.credit_micros_balance == balance_after_first, ( - "balance should not change for unchanged files" + assert db_user.pages_used == pages_after_first, ( + "pages_used should not increase for unchanged files" ) @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr6_batch_partial_credit_exhaustion( + async def test_pl6_batch_partial_page_limit_exhaustion( self, db_session: AsyncSession, db_user: User, @@ -1164,11 +1146,11 @@ class TestEtlCredits: tmp_path: Path, patched_batch_sessions, ): - """CR6: Batch mode with a tiny balance: some files succeed, rest fail.""" + """PL6: Batch mode with a very low page limit: some files succeed, rest fail.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - # Exactly one page of credit. - db_user.credit_micros_balance = self._micros(1) + db_user.pages_used = 0 + db_user.pages_limit = 1 await db_session.flush() (tmp_path / "a.md").write_text("File A content") @@ -1189,13 +1171,12 @@ class TestEtlCredits: ) assert count >= 1, "at least one file should succeed" - assert failed >= 1, "at least one file should fail due to insufficient credits" + assert failed >= 1, "at least one file should fail due to page limit" assert count + failed == 3 await db_session.refresh(db_user) - # The wallet was drained by the successful file(s); it may dip slightly - # negative when the actual page count exceeds the pre-check estimate. - assert db_user.credit_micros_balance <= 0 + assert db_user.pages_used > 0 + assert db_user.pages_used <= db_user.pages_limit + 1 # ==================================================================== diff --git a/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py index 5ca560f11..f602f2e66 100644 --- a/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py +++ b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py @@ -78,23 +78,3 @@ async def test_processing_completed_failure( assert done.title == "Failed: report.pdf" assert done.message == "Processing failed: bad file" assert done.notification_metadata["status"] == "failed" - - -async def test_processing_started_truncates_long_filename( - db_session: AsyncSession, db_user: User, db_search_space: SearchSpace -): - """A long filename is truncated in the title but kept in metadata.""" - long_name = "a" * 250 - - notification = await handler.notify_processing_started( - session=db_session, - user_id=db_user.id, - document_type="FILE", - document_name=long_name, - search_space_id=db_search_space.id, - ) - - assert len(notification.title) <= 200 - assert notification.title.startswith("Processing: ") - assert notification.title.endswith("...") - assert notification.notification_metadata["document_name"] == long_name diff --git a/surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py similarity index 52% rename from surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py rename to surfsense_backend/tests/integration/notifications/test_page_limit_handler.py index bdfa1b30c..ab89d63c9 100644 --- a/surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py +++ b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py @@ -1,4 +1,4 @@ -"""Behavior guard for the insufficient-credits notification handler.""" +"""Behavior guard for the page-limit notification handler.""" from __future__ import annotations @@ -10,50 +10,52 @@ from app.notifications.service import NotificationService pytestmark = pytest.mark.integration -handler = NotificationService.insufficient_credits +handler = NotificationService.page_limit -async def test_insufficient_credits_message_and_action( +async def test_page_limit_message_and_action( db_session: AsyncSession, db_user: User, db_search_space: SearchSpace ): - """An insufficient-credits notification states cost and carries a buy-credits link.""" - notification = await handler.notify_insufficient_credits( + """A page-limit notification states usage and carries an upgrade action link.""" + notification = await handler.notify_page_limit_exceeded( session=db_session, user_id=db_user.id, document_name="short.pdf", document_type="FILE", search_space_id=db_search_space.id, - balance_micros=250_000, - required_micros=1_000_000, + pages_used=95, + pages_limit=100, + pages_to_add=10, ) - assert notification.type == "insufficient_credits" - assert notification.title == "Insufficient credits: short.pdf" + assert notification.type == "page_limit_exceeded" + assert notification.title == "Page limit exceeded: short.pdf" assert notification.message == ( - "This document costs about $1.00 to process but you have " - "$0.25 of credit left. Add more credits to continue." + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." ) assert notification.notification_metadata["status"] == "failed" - assert notification.notification_metadata["action_label"] == "Buy credits" + assert notification.notification_metadata["action_label"] == "Upgrade Plan" assert notification.notification_metadata["action_url"] == ( - f"/dashboard/{db_search_space.id}/buy-more" + f"/dashboard/{db_search_space.id}/more-pages" ) -async def test_insufficient_credits_truncates_long_name( +async def test_page_limit_truncates_long_name( db_session: AsyncSession, db_user: User, db_search_space: SearchSpace ): """A long document name is truncated in the notification title.""" long_name = "a" * 50 - notification = await handler.notify_insufficient_credits( + notification = await handler.notify_page_limit_exceeded( session=db_session, user_id=db_user.id, document_name=long_name, document_type="FILE", search_space_id=db_search_space.id, - balance_micros=250_000, - required_micros=1_000_000, + pages_used=95, + pages_limit=100, + pages_to_add=10, ) - assert notification.title == f"Insufficient credits: {'a' * 40}..." + assert notification.title == f"Page limit exceeded: {'a' * 40}..." diff --git a/surfsense_backend/tests/integration/podcasts/conftest.py b/surfsense_backend/tests/integration/podcasts/conftest.py deleted file mode 100644 index 75248a6a1..000000000 --- a/surfsense_backend/tests/integration/podcasts/conftest.py +++ /dev/null @@ -1,324 +0,0 @@ -"""Podcast API + task integration fixtures. - -The app's DB session and current-user dependencies ride the test's transactional -`db_session`, so seeded rows and rows touched through the endpoints (or the task -bodies) share one transaction that rolls back per test. Only true externals are -faked: the Celery broker (`*_task.delay`) is captured instead of dispatched, the -object store is a tiny in-memory backend, the Celery tasks' own session maker is -bound to the test transaction, and — for the render task — the TTS provider and -the FFmpeg merge are stubbed. `TTS_SERVICE` is pinned so the deterministic brief -proposal can resolve voices. -""" - -from __future__ import annotations - -import contextlib -import uuid -from collections.abc import AsyncGenerator, AsyncIterator -from pathlib import Path - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport -from sqlalchemy.ext.asyncio import AsyncSession - -from app.app import app, limiter -from app.config import config as app_config -from app.db import SearchSpace, User, get_async_session -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) -from app.podcasts.service import PodcastService -from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech -from app.routes.search_spaces_routes import create_default_roles_and_membership -from app.users import current_active_user - -pytestmark = pytest.mark.integration - -limiter.enabled = False - - -@pytest_asyncio.fixture -async def client( - db_session: AsyncSession, - db_user: User, -) -> AsyncGenerator[httpx.AsyncClient, None]: - async def override_session() -> AsyncGenerator[AsyncSession, None]: - yield db_session - - async def override_user() -> User: - return db_user - - previous_overrides = app.dependency_overrides.copy() - app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user - - try: - async with httpx.AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - timeout=30.0, - follow_redirects=False, - ) as test_client: - yield test_client - finally: - app.dependency_overrides.clear() - app.dependency_overrides.update(previous_overrides) - - -@pytest.fixture(autouse=True) -def tts_service(monkeypatch) -> str: - """Pin a provider with language-agnostic voices so brief proposal resolves.""" - service = "openai/tts-1" - monkeypatch.setattr(app_config, "TTS_SERVICE", service) - return service - - -class CapturedTasks: - """Records the args each podcast Celery task was enqueued with.""" - - def __init__(self) -> None: - self.draft: list[tuple] = [] - self.render: list[tuple] = [] - - -@pytest.fixture(autouse=True) -def captured_tasks(monkeypatch) -> CapturedTasks: - """Capture `*_task.delay` instead of hitting the broker (a boundary).""" - captured = CapturedTasks() - from app.podcasts.tasks import draft_transcript_task, render_audio_task - - monkeypatch.setattr( - draft_transcript_task, "delay", lambda *a, **k: captured.draft.append((a, k)) - ) - monkeypatch.setattr( - render_audio_task, "delay", lambda *a, **k: captured.render.append((a, k)) - ) - return captured - - -class FakeStorageBackend: - """In-memory object store standing in for the real audio backend.""" - - backend_name = "memory" - - def __init__(self) -> None: - self.objects: dict[str, bytes] = {} - self.deleted: list[str] = [] - - async def put(self, key: str, data: bytes, content_type: str | None = None) -> None: - self.objects[key] = data - - async def open_stream(self, key: str) -> AsyncIterator[bytes]: - yield self.objects.get(key, b"audio-bytes") - - async def exists(self, key: str) -> bool: - return key in self.objects - - async def delete(self, key: str) -> None: - self.deleted.append(key) - - -@pytest.fixture -def fake_storage(monkeypatch) -> FakeStorageBackend: - """Route audio storage to an in-memory backend for the stream routes.""" - backend = FakeStorageBackend() - monkeypatch.setattr("app.podcasts.storage.get_storage_backend", lambda: backend) - monkeypatch.setattr("app.file_storage.factory.get_storage_backend", lambda: backend) - return backend - - -@pytest.fixture -def bind_task_session(db_session: AsyncSession, monkeypatch) -> AsyncSession: - """Bind the Celery tasks' own session maker to the test transaction. - - Task bodies open ``get_celery_session_maker()()`` rather than receiving a - session, so this hands them the test's session without closing it on exit; a - task's ``commit()`` then releases a savepoint and the per-test rollback still - cleans up. - """ - - def _make_session(): - @contextlib.asynccontextmanager - async def _ctx() -> AsyncIterator[AsyncSession]: - yield db_session - - return _ctx() - - for module in ( - "app.podcasts.tasks.draft", - "app.podcasts.tasks.render", - "app.podcasts.tasks.runtime", - ): - monkeypatch.setattr(f"{module}.get_celery_session_maker", lambda: _make_session) - return db_session - - -class FakeTextToSpeech(TextToSpeech): - """In-memory TTS provider: every segment yields fixed bytes (the boundary). - - Records each request so tests can assert how often synthesis was paid for. - """ - - def __init__(self) -> None: - self.requests: list[SynthesisRequest] = [] - - @property - def container(self) -> str: - return "mp3" - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - self.requests.append(request) - return SynthesizedAudio(data=b"segment-audio", container="mp3") - - -@pytest.fixture -def fake_tts(monkeypatch) -> FakeTextToSpeech: - """Stand in for the configured TTS provider in the render task.""" - provider = FakeTextToSpeech() - monkeypatch.setattr( - "app.podcasts.tasks.render.get_text_to_speech", lambda: provider - ) - return provider - - -@pytest.fixture -def fake_merge(monkeypatch) -> None: - """Stub the FFmpeg merge (an external binary) to emit a fixed MP3.""" - - async def _merge(segment_paths: list[Path], output_path: Path) -> None: - output_path.write_bytes(b"merged-audio") - - monkeypatch.setattr("app.podcasts.rendering.renderer.concat_to_mp3", _merge) - - -def build_spec( - *, - language: str = "en", - voice_ids: tuple[str, str] = ("openai:alloy", "openai:nova"), -) -> PodcastSpec: - """A valid two-speaker brief; tests override only what they assert on.""" - return PodcastSpec( - language=language, - style=PodcastStyle.CONVERSATIONAL, - speakers=[ - SpeakerSpec( - slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_ids[0] - ), - SpeakerSpec( - slot=1, name="Guest", role=SpeakerRole.GUEST, voice_id=voice_ids[1] - ), - ], - duration=DurationTarget(min_seconds=600, max_seconds=1200), - ) - - -def build_transcript() -> Transcript: - return Transcript( - turns=[ - TranscriptTurn(speaker=0, text="Welcome to the show."), - TranscriptTurn(speaker=1, text="Glad to be here."), - ] - ) - - -@pytest.fixture -def make_podcast(db_session: AsyncSession): - """Create a podcast advanced to a target lifecycle state via the service. - - Setup runs through the same public service the API uses, on the test's - session, so the endpoint under test reads a realistically-built row. - """ - - ladder = [ - PodcastStatus.AWAITING_BRIEF, - PodcastStatus.DRAFTING, - PodcastStatus.RENDERING, - PodcastStatus.READY, - ] - - async def _make( - *, - search_space_id: int, - status: PodcastStatus = PodcastStatus.AWAITING_BRIEF, - title: str = "Test Podcast", - thread_id: int | None = None, - ) -> Podcast: - service = PodcastService(db_session) - podcast = await service.create( - title=title, search_space_id=search_space_id, thread_id=thread_id - ) - if status is PodcastStatus.PENDING: - await db_session.flush() - return podcast - - targets = ladder[: ladder.index(status) + 1] - for target in targets: - if target is PodcastStatus.AWAITING_BRIEF: - await service.attach_brief(podcast, build_spec()) - elif target is PodcastStatus.DRAFTING: - await service.begin_drafting(podcast) - elif target is PodcastStatus.RENDERING: - await service.attach_transcript(podcast, build_transcript()) - elif target is PodcastStatus.READY: - await service.attach_audio( - podcast, - storage_backend="memory", - storage_key="podcasts/audio.mp3", - duration_seconds=123, - ) - await db_session.flush() - return podcast - - return _make - - -@pytest.fixture -def act_as(): - """Switch the authenticated user for subsequent requests on ``client``. - - The ``client`` fixture installs db_user and restores the prior overrides on - teardown, so re-pointing the auth dependency here is undone per test. - """ - - def _act(user: User) -> None: - app.dependency_overrides[current_active_user] = lambda: user - - return _act - - -@pytest_asyncio.fixture -async def db_other_user(db_session: AsyncSession) -> User: - """A second user who is not a member of ``db_search_space``.""" - user = User( - id=uuid.uuid4(), - email="stranger@surfsense.net", - hashed_password="hashed", - is_active=True, - is_superuser=False, - is_verified=True, - ) - db_session.add(user) - await db_session.flush() - return user - - -@pytest_asyncio.fixture -async def foreign_podcast( - db_session: AsyncSession, db_other_user: User, make_podcast -) -> Podcast: - """A podcast in a space owned by the other user, invisible to db_user.""" - space = SearchSpace(name="Stranger Space", user_id=db_other_user.id) - db_session.add(space) - await db_session.flush() - await create_default_roles_and_membership(db_session, space.id, db_other_user.id) - await db_session.flush() - return await make_podcast(search_space_id=space.id, title="Foreign") diff --git a/surfsense_backend/tests/integration/podcasts/test_brief_gate.py b/surfsense_backend/tests/integration/podcasts/test_brief_gate.py deleted file mode 100644 index 46d97172d..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_brief_gate.py +++ /dev/null @@ -1,80 +0,0 @@ -"""The brief review gate: edit the spec, then approve to start drafting. - -Covers what the user can do while ``awaiting_brief`` — edit the brief under -optimistic concurrency and approve it — and the HTTP status codes the service's -guards map to when an edit races or comes too late. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def _create(client, search_space_id: int) -> dict: - resp = await client.post( - BASE, - json={ - "title": "Episode", - "search_space_id": search_space_id, - "source_content": "Source content.", - }, - ) - assert resp.status_code == 201 - return resp.json() - - -async def test_approve_brief_starts_drafting_and_enqueues_draft( - client, db_search_space, captured_tasks -): - podcast = await _create(client, db_search_space.id) - - resp = await client.post(f"{BASE}/{podcast['id']}/brief/approve") - - assert resp.status_code == 200 - assert resp.json()["status"] == "drafting" - assert captured_tasks.draft == [((podcast["id"], db_search_space.id), {})] - assert captured_tasks.render == [] - - -async def test_update_spec_bumps_version_and_persists(client, db_search_space): - podcast = await _create(client, db_search_space.id) - spec = podcast["spec"] - spec["focus"] = "A sharper angle" - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": spec, "expected_version": podcast["spec_version"]}, - ) - - assert resp.status_code == 200 - body = resp.json() - assert body["spec_version"] == podcast["spec_version"] + 1 - assert body["spec"]["focus"] == "A sharper angle" - assert body["status"] == "awaiting_brief" - - -async def test_update_spec_with_stale_version_conflicts(client, db_search_space): - podcast = await _create(client, db_search_space.id) - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": podcast["spec"], "expected_version": 999}, - ) - - assert resp.status_code == 409 - - -async def test_update_spec_after_approval_is_rejected(client, db_search_space): - podcast = await _create(client, db_search_space.id) - await client.post(f"{BASE}/{podcast['id']}/brief/approve") - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": podcast["spec"], "expected_version": podcast["spec_version"]}, - ) - - assert resp.status_code == 409 diff --git a/surfsense_backend/tests/integration/podcasts/test_cancel.py b/surfsense_backend/tests/integration/podcasts/test_cancel.py deleted file mode 100644 index 4fe4cfc55..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_cancel.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Cancelling a podcast: allowed while in flight, refused once an episode exists. - -Cancellation is the escape hatch for a podcast that has produced nothing yet. -Once a finished episode exists — including during a regeneration, whose audio -survives until a new render commits — cancel is refused (409): reverting the -regeneration is the way back, and no user action may destroy playable audio. -""" - -import pytest - -from app.podcasts.persistence import PodcastStatus - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_cancel_from_a_live_state_succeeds(client, db_search_space, make_podcast): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 200 - assert resp.json()["status"] == "cancelled" - - -async def test_cancel_from_a_terminal_state_conflicts( - client, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 409 - - -async def test_cancel_of_a_regeneration_is_rejected( - client, db_search_space, make_podcast -): - # Cancelling here would destroy a playable episode; reverting the - # regeneration is the way back. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 409 - # The regeneration is still revertable afterwards. - follow_up = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - assert follow_up.status_code == 200 - assert follow_up.json()["status"] == "ready" diff --git a/surfsense_backend/tests/integration/podcasts/test_create.py b/surfsense_backend/tests/integration/podcasts/test_create.py deleted file mode 100644 index 19b5aeca2..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_create.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Creating a podcast proposes a brief and opens the review gate. - -Driven through the real POST endpoint (auth + DB on one transaction): the row is -created, a brief is proposed inline from defaults, and the podcast lands in -``awaiting_brief`` with a complete spec and nothing generated yet. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_create_proposes_brief_and_opens_gate(client, db_search_space): - resp = await client.post( - BASE, - json={ - "title": "My Episode", - "search_space_id": db_search_space.id, - "source_content": "A long piece of source content about a topic.", - }, - ) - - assert resp.status_code == 201 - body = resp.json() - assert body["title"] == "My Episode" - assert body["status"] == "awaiting_brief" - assert body["spec_version"] == 1 - assert body["spec"] is not None - assert body["spec"]["language"] == "en" - assert len(body["spec"]["speakers"]) == 2 - assert body["transcript"] is None - assert body["has_audio"] is False - - -async def test_create_honors_requested_speaker_count(client, db_search_space): - resp = await client.post( - BASE, - json={ - "title": "Solo", - "search_space_id": db_search_space.id, - "source_content": "Content.", - "speaker_count": 3, - }, - ) - - assert resp.status_code == 201 - assert len(resp.json()["spec"]["speakers"]) == 3 diff --git a/surfsense_backend/tests/integration/podcasts/test_draft_task.py b/surfsense_backend/tests/integration/podcasts/test_draft_task.py deleted file mode 100644 index 014d98b1f..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_draft_task.py +++ /dev/null @@ -1,116 +0,0 @@ -"""The transcript-drafting task against a real database. - -Drafting is the expensive LLM step, so it runs under ``billable_call``. The -behavior that protects users' money: when billing succeeds, the drafted -transcript is stored and rendering starts immediately (DRAFTING -> RENDERING, -render task enqueued — the brief gate is the only approval); when billing denies -or settlement fails, the podcast ends FAILED with no transcript left behind. The -DB, service, and transcript persistence run for real; only the true externals -are faked — billing (the metering boundary) and the generation graph (the LLM). -""" - -from __future__ import annotations - -from contextlib import asynccontextmanager -from types import SimpleNamespace -from uuid import uuid4 - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.service import read_transcript -from app.podcasts.tasks import draft -from app.services.billable_calls import ( - BillingSettlementError, - QuotaInsufficientError, -) - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - - -def _wire_billing(monkeypatch, *, billable_call, transcript=None) -> None: - """Replace the billing + LLM externals the draft body reaches for.""" - - async def _resolver(_session, _search_space_id, *, thread_id=None): - return uuid4(), "free", "openrouter/model" - - async def _ainvoke(_state, config=None): - return {"transcript": transcript} - - monkeypatch.setattr(draft, "_resolve_agent_billing_for_search_space", _resolver) - monkeypatch.setattr(draft, "billable_call", billable_call) - monkeypatch.setattr(draft, "transcript_graph", SimpleNamespace(ainvoke=_ainvoke)) - - -async def test_successful_draft_stores_transcript_and_starts_rendering( - monkeypatch, db_search_space, make_podcast, bind_task_session, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _ok(**_kwargs): - yield SimpleNamespace() - - _wire_billing(monkeypatch, billable_call=_ok, transcript=build_transcript()) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["status"] == "rendering" - assert podcast.status == PodcastStatus.RENDERING - assert read_transcript(podcast) is not None - assert captured_tasks.render == [((podcast.id,), {})] - - -async def test_quota_denial_fails_the_podcast_without_a_transcript( - monkeypatch, db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _deny(**_kwargs): - raise QuotaInsufficientError( - usage_type="podcast_generation", - balance_micros=5_000_000, - remaining_micros=0, - ) - yield # pragma: no cover - unreachable, satisfies the CM protocol - - _wire_billing(monkeypatch, billable_call=_deny) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["reason"] == "quota" - assert podcast.status == PodcastStatus.FAILED - assert read_transcript(podcast) is None - - -async def test_billing_settlement_failure_fails_the_podcast( - monkeypatch, db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _settlement_fails(**_kwargs): - yield SimpleNamespace() - raise BillingSettlementError( - usage_type="podcast_generation", - user_id=uuid4(), - cause=RuntimeError("finalize failed"), - ) - - _wire_billing( - monkeypatch, billable_call=_settlement_fails, transcript=build_transcript() - ) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["reason"] == "billing" - assert podcast.status == PodcastStatus.FAILED diff --git a/surfsense_backend/tests/integration/podcasts/test_public_stream.py b/surfsense_backend/tests/integration/podcasts/test_public_stream.py deleted file mode 100644 index 63f634234..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_public_stream.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Public (unauthenticated) podcast streaming from a chat snapshot. - -A shared chat snapshot carries each podcast's stored-audio key; the public route -streams those bytes from the object store via ``share_token`` with no auth. A -podcast that isn't in the snapshot is a 404. -""" - -import pytest - -from app.db import NewChatThread, PublicChatSnapshot, User - -pytestmark = pytest.mark.integration - - -async def _snapshot(db_session, *, search_space_id, user: User, token: str, podcasts): - thread = NewChatThread( - title="Shared", search_space_id=search_space_id, created_by_id=user.id - ) - db_session.add(thread) - await db_session.flush() - snapshot = PublicChatSnapshot( - thread_id=thread.id, - share_token=token, - content_hash=f"hash-{token}", - message_ids=[], - snapshot_data={"podcasts": podcasts}, - ) - db_session.add(snapshot) - await db_session.flush() - - -async def test_public_stream_serves_audio_via_storage_key( - client, db_session, db_search_space, db_user, fake_storage -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-audio", - podcasts=[{"original_id": 555, "storage_key": "podcasts/x.mp3"}], - ) - fake_storage.objects["podcasts/x.mp3"] = b"public-audio" - - resp = await client.get("/api/v1/public/tok-audio/podcasts/555/stream") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"public-audio" - - -async def test_public_stream_404_when_object_missing( - client, db_session, db_search_space, db_user, fake_storage -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-gone", - podcasts=[{"original_id": 556, "storage_key": "podcasts/gone.mp3"}], - ) - - resp = await client.get("/api/v1/public/tok-gone/podcasts/556/stream") - - assert resp.status_code == 404 - - -async def test_public_stream_404_when_podcast_absent_from_snapshot( - client, db_session, db_search_space, db_user -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-empty", - podcasts=[], - ) - - resp = await client.get("/api/v1/public/tok-empty/podcasts/999/stream") - - assert resp.status_code == 404 diff --git a/surfsense_backend/tests/integration/podcasts/test_regeneration.py b/surfsense_backend/tests/integration/podcasts/test_regeneration.py deleted file mode 100644 index fd31df4ca..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_regeneration.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Regeneration: the listen-then-redo loop after the brief gate. - -A user who dislikes the finished audio sends the episode back to the brief -gate: the saved brief reopens for tweaks (voices, length, focus) and drafting -only restarts on a fresh approval. The whole redo can also be reverted at any -point before the new render commits, falling back to the still-stored episode. -These pin the READY -> AWAITING_BRIEF -> DRAFTING round trip, the revert -fallback, and the 409s for acting from states that have nothing to redo or -revert. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.service import PodcastService - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_regenerate_from_ready_reopens_the_brief_gate( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "awaiting_brief" - # The prior brief is kept as the starting point for the new take. - assert body["spec"] is not None - # Nothing drafts until the user approves the reopened brief. - assert captured_tasks.draft == [] - assert captured_tasks.render == [] - - -async def test_approving_the_reopened_brief_starts_a_fresh_draft( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/brief/approve") - - assert resp.status_code == 200 - assert resp.json()["status"] == "drafting" - assert captured_tasks.draft == [((podcast.id, db_search_space.id), {})] - - -async def test_regenerate_from_brief_gate_is_rejected( - client, db_search_space, make_podcast, captured_tasks -): - # Nothing has been drafted yet, so there is nothing to regenerate. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 409 - assert captured_tasks.draft == [] - - -async def test_regenerate_from_cancelled_is_rejected( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - await client.post(f"{BASE}/{podcast.id}/cancel") - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 409 - assert captured_tasks.draft == [] - - -async def test_reverting_a_regeneration_restores_the_ready_episode( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ready" - # The episode the user could already play is untouched. - assert body["has_audio"] is True - assert captured_tasks.draft == [] - assert captured_tasks.render == [] - - -async def test_reverting_mid_draft_keeps_the_episode( - client, db_search_space, make_podcast -): - # Changing one's mind is allowed even after the reopened brief was - # approved: the episode survives until a new render replaces it. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - await client.post(f"{BASE}/{podcast.id}/brief/approve") - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - assert resp.json()["status"] == "ready" - - -async def test_reverting_mid_render_keeps_the_episode( - client, db_session, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - assert resp.json()["status"] == "ready" - - -async def test_reverted_episode_can_be_regenerated_again( - client, db_search_space, make_podcast -): - # Reverting must not strand the episode: the user can change their mind - # again immediately. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 200 - assert resp.json()["status"] == "awaiting_brief" - - -async def test_revert_on_a_fresh_brief_gate_is_rejected( - client, db_search_space, make_podcast -): - # A first-time brief has no regeneration to revert. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 409 - assert resp.json()["detail"] - - -async def test_revert_when_nothing_was_regenerated_is_rejected( - client, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 409 - - -async def test_regenerate_without_a_brief_is_rejected( - client, db_session, db_search_space, captured_tasks -): - # Legacy episodes finished before briefs existed; reopening a gate with - # nothing to review would strand them there. - podcast = Podcast( - title="Legacy Episode", - search_space_id=db_search_space.id, - status=PodcastStatus.READY, - spec_version=1, - file_location="/var/old/podcast.mp3", - ) - db_session.add(podcast) - await db_session.flush() - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 422 - assert captured_tasks.draft == [] diff --git a/surfsense_backend/tests/integration/podcasts/test_render_task.py b/surfsense_backend/tests/integration/podcasts/test_render_task.py deleted file mode 100644 index 5a97a00c7..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_render_task.py +++ /dev/null @@ -1,100 +0,0 @@ -"""The audio-rendering task against a real database. - -From RENDERING, the task synthesises and merges the approved transcript, stores -the bytes, and marks the podcast READY with the storage location recorded. The -DB, service, renderer orchestration, and storage wrapper run for real; the true -externals are faked — the TTS provider, the FFmpeg merge, and the object store. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.service import PodcastService -from app.podcasts.tasks import render - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - - -async def test_render_marks_ready_and_stores_audio( - db_search_space, make_podcast, bind_task_session, fake_tts, fake_merge, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.RENDERING - ) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "ready" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_backend == "memory" - assert podcast.storage_key - assert fake_storage.objects[podcast.storage_key] == b"merged-audio" - - -async def test_rerender_replaces_audio_and_purges_the_old_object( - db_session, - db_search_space, - make_podcast, - bind_task_session, - fake_tts, - fake_merge, - fake_storage, -): - # A regenerated episode keeps exactly one stored object: the new render - # must not leak the superseded audio in the object store. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - old_key = podcast.storage_key - fake_storage.objects[old_key] = b"old-audio" - - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "ready" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_key != old_key - assert fake_storage.objects[podcast.storage_key] == b"merged-audio" - assert old_key in fake_storage.deleted - - -async def test_render_losing_to_a_user_revert_keeps_the_episode_and_leaks_nothing( - db_session, - db_search_space, - make_podcast, - bind_task_session, - fake_tts, - fake_merge, - fake_storage, -): - # The user reverts the regeneration while the render is in flight: the - # stale render must neither resurrect the redo nor leak the object it - # already stored. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - old_key = podcast.storage_key - fake_storage.objects[old_key] = b"old-audio" - - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - await service.revert_regeneration(podcast) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "superseded" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_key == old_key - assert old_key not in fake_storage.deleted - stale_keys = [key for key in fake_storage.objects if key != old_key] - assert all(key in fake_storage.deleted for key in stale_keys) diff --git a/surfsense_backend/tests/integration/podcasts/test_scoping.py b/surfsense_backend/tests/integration/podcasts/test_scoping.py deleted file mode 100644 index 304af6b6e..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_scoping.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Podcasts are scoped to search-space membership. - -A user can only create or read podcasts in spaces they belong to, and an -unscoped listing returns only the caller's own podcasts — never another -member's. -""" - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_reading_a_podcast_in_a_nonmember_space_is_forbidden( - client, db_search_space, make_podcast, act_as, db_other_user -): - podcast = await make_podcast(search_space_id=db_search_space.id) - act_as(db_other_user) - - resp = await client.get(f"{BASE}/{podcast.id}") - - assert resp.status_code == 403 - - -async def test_creating_in_a_nonmember_space_is_forbidden( - client, db_search_space, act_as, db_other_user -): - act_as(db_other_user) - - resp = await client.post( - BASE, - json={ - "title": "X", - "search_space_id": db_search_space.id, - "source_content": "content", - }, - ) - - assert resp.status_code == 403 - - -async def test_listing_returns_only_the_callers_podcasts( - client, db_search_space, make_podcast, foreign_podcast -): - mine = await make_podcast(search_space_id=db_search_space.id, title="Mine") - - resp = await client.get(BASE) - - assert resp.status_code == 200 - ids = {p["id"] for p in resp.json()} - assert mine.id in ids - assert foreign_podcast.id not in ids diff --git a/surfsense_backend/tests/integration/podcasts/test_streaming.py b/surfsense_backend/tests/integration/podcasts/test_streaming.py deleted file mode 100644 index b924e2971..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_streaming.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Streaming a podcast's rendered audio over HTTP. - -A ready podcast streams its bytes; an in-flight one is 409, a stored-but-missing -object is 404. Storage is an in-memory backend (the object store is a boundary). -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_stream_serves_stored_audio( - client, db_search_space, make_podcast, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - fake_storage.objects["podcasts/audio.mp3"] = b"the-audio" - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"the-audio" - - -async def test_stream_409_while_in_flight(client, db_search_space, make_podcast): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - - assert resp.status_code == 409 - - -async def test_stream_404_when_object_missing( - client, db_search_space, make_podcast, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - - assert resp.status_code == 404 diff --git a/surfsense_backend/tests/integration/podcasts/test_task_failure.py b/surfsense_backend/tests/integration/podcasts/test_task_failure.py deleted file mode 100644 index 43212f58f..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_task_failure.py +++ /dev/null @@ -1,45 +0,0 @@ -"""The task failure safety net (``mark_failed``) against a real database. - -When a task body raises, ``mark_failed`` records the reason on the row. Its -contract has two halves worth securing: a still-running podcast moves to FAILED -with the reason, while one that already reached a terminal state is left exactly -as it was rather than forced. A missing row is a no-op, never a crash. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.tasks import runtime - -pytestmark = pytest.mark.integration - - -async def test_marking_failed_records_the_reason_on_a_running_podcast( - db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - await runtime.mark_failed(podcast.id, "tts provider unavailable") - - assert podcast.status == PodcastStatus.FAILED - assert podcast.error == "tts provider unavailable" - - -async def test_marking_failed_leaves_an_already_terminal_podcast_untouched( - db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - await runtime.mark_failed(podcast.id, "too late") - - assert podcast.status == PodcastStatus.READY - - -async def test_marking_a_missing_podcast_failed_is_a_no_op(bind_task_session): - await runtime.mark_failed(987654321, "gone") # must not raise diff --git a/surfsense_backend/tests/integration/podcasts/test_voice_preview.py b/surfsense_backend/tests/integration/podcasts/test_voice_preview.py deleted file mode 100644 index 113172bee..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_voice_preview.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Audible voice previews for the brief gate's voice picker. - -A user choosing voices should hear them, not guess from names. The endpoint -synthesises a short sample for a catalog voice and caches it on disk so each -voice is paid for at most once per process lifetime. Unknown voices and voices -of an inactive provider are 404; no configured TTS is 503. -""" - -from __future__ import annotations - -import pytest - -from app.config import config as app_config - -from .conftest import FakeTextToSpeech - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -@pytest.fixture -def preview_tts(monkeypatch, tmp_path) -> FakeTextToSpeech: - """Route preview synthesis to the fake provider and an isolated cache.""" - provider = FakeTextToSpeech() - monkeypatch.setattr("app.podcasts.api.routes.get_text_to_speech", lambda: provider) - monkeypatch.setattr("app.podcasts.voices.preview.PREVIEW_CACHE_ROOT", tmp_path) - return provider - - -async def test_preview_returns_playable_audio_for_a_catalog_voice(client, preview_tts): - resp = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"segment-audio" - - -async def test_preview_is_synthesised_once_then_served_from_cache(client, preview_tts): - first = await client.get(f"{BASE}/voices/openai:alloy/preview") - second = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert first.status_code == second.status_code == 200 - assert second.content == first.content - assert len(preview_tts.requests) == 1 - - -async def test_preview_unknown_voice_is_404(client, preview_tts): - resp = await client.get(f"{BASE}/voices/openai:nope/preview") - - assert resp.status_code == 404 - assert preview_tts.requests == [] - - -async def test_preview_voice_of_inactive_provider_is_404(client, preview_tts): - # The active provider is OpenAI (pinned in conftest); a Kokoro voice exists - # in the catalog but cannot be heard through the configured provider. - resp = await client.get(f"{BASE}/voices/kokoro:af_heart/preview") - - assert resp.status_code == 404 - assert preview_tts.requests == [] - - -async def test_preview_without_tts_provider_is_503(client, preview_tts, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", None) - - resp = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert resp.status_code == 503 diff --git a/surfsense_backend/tests/integration/podcasts/test_voices.py b/surfsense_backend/tests/integration/podcasts/test_voices.py deleted file mode 100644 index fd41bfd4e..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_voices.py +++ /dev/null @@ -1,51 +0,0 @@ -"""GET /podcasts/voices: the active provider's catalog, or 503 if unconfigured. - -The brief UI needs the voices the configured TTS provider offers; with no -provider configured there is nothing to choose from, which is a 503 rather than -an empty list. -""" - -import pytest - -from app.config import config as app_config - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_voices_returns_the_active_providers_catalog(client): - resp = await client.get(f"{BASE}/voices") - - assert resp.status_code == 200 - voices = resp.json() - assert voices # openai/tts-1 offers voices - assert {"voice_id", "display_name", "language", "gender"} <= voices[0].keys() - - -async def test_voices_503_when_no_tts_configured(client, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", "") - - resp = await client.get(f"{BASE}/voices") - - assert resp.status_code == 503 - - -async def test_languages_returns_the_active_providers_offering(client): - """The brief form renders exactly what the backend offers — for a wildcard - provider (openai/tts-1) that is the curated list plus free entry.""" - resp = await client.get(f"{BASE}/languages") - - assert resp.status_code == 200 - offering = resp.json() - assert "en" in offering["languages"] - assert "fr" in offering["languages"] - assert offering["allows_custom"] is True - - -async def test_languages_503_when_no_tts_configured(client, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", "") - - resp = await client.get(f"{BASE}/languages") - - assert resp.status_code == 503 diff --git a/surfsense_backend/tests/integration/test_connector_index_authz.py b/surfsense_backend/tests/integration/test_connector_index_authz.py deleted file mode 100644 index cea2407cc..000000000 --- a/surfsense_backend/tests/integration/test_connector_index_authz.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Cross-search-space authorization on the connector index endpoint. - -``POST /search-source-connectors/{connector_id}/index?search_space_id=<X>`` must -authorize against the **connector's own** ``search_space_id`` (matching the -read/update/delete handlers), not the caller-supplied ``search_space_id`` query -parameter, and must reject a connector that does not belong to the requested -search space. - -Without this, a user who owns search space B could index another user's -connector (which lives in space A) by passing ``search_space_id=B``: the -background indexer would run with the **victim connector's stored credentials** -and write the fetched content into the attacker's space. These tests pin that -boundary. -""" - -from __future__ import annotations - -import contextlib -import uuid -from unittest.mock import AsyncMock, patch - -import pytest -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import ( - SearchSourceConnector, - SearchSourceConnectorType, - SearchSpace, - User, -) -from app.routes.search_source_connectors_routes import index_connector_content -from app.routes.search_spaces_routes import create_default_roles_and_membership - -pytestmark = pytest.mark.integration - -# The handler imports ``check_permission`` into its own module namespace. -_CHECK_PERMISSION = "app.routes.search_source_connectors_routes.check_permission" - - -async def _make_user_with_space(session: AsyncSession) -> tuple[User, SearchSpace]: - """A user plus a search space they own, with the default roles/membership - the ``POST /searchspaces`` route would create (so ``check_permission`` would - legitimately pass for this user on this space).""" - user = User( - id=uuid.uuid4(), - email=f"authz-{uuid.uuid4()}@surfsense.test", - hashed_password="x", - is_active=True, - is_superuser=False, - is_verified=True, - ) - session.add(user) - await session.flush() - space = SearchSpace(name=f"Space {uuid.uuid4().hex[:8]}", user_id=user.id) - session.add(space) - await session.flush() - await create_default_roles_and_membership(session, space.id, user.id) - await session.flush() - return user, space - - -async def _make_connector( - session: AsyncSession, - owner: User, - space: SearchSpace, - connector_type: SearchSourceConnectorType, -) -> SearchSourceConnector: - connector = SearchSourceConnector( - name="Connector", - connector_type=connector_type, - # A stored credential the indexer would use — the thing a cross-tenant - # index must never be able to abuse. - config={ - "GITHUB_PAT": "victim-secret-pat", - "repo_full_names": ["octocat/Hello-World"], - }, - is_indexable=True, - search_space_id=space.id, - user_id=owner.id, - ) - session.add(connector) - await session.flush() - return connector - - -class TestConnectorIndexCrossSpaceAuthz: - async def test_cross_space_index_is_rejected_before_permission_check( - self, db_session: AsyncSession - ): - """Attacker (owns space B) cannot index victim's connector (in space A) - by passing ``search_space_id=B``. - - The mismatch is rejected with 404 **before** ``check_permission`` runs — - which is essential, because that permission check *would* pass: the - attacker legitimately holds ``CONNECTORS_UPDATE`` on their own space B. - """ - victim, space_a = await _make_user_with_space(db_session) - attacker, space_b = await _make_user_with_space(db_session) - connector_a = await _make_connector( - db_session, victim, space_a, SearchSourceConnectorType.GITHUB_CONNECTOR - ) - - with ( - patch(_CHECK_PERMISSION, new=AsyncMock()) as check_permission_mock, - pytest.raises(HTTPException) as exc_info, - ): - await index_connector_content( - connector_id=connector_a.id, - search_space_id=space_b.id, # the attacker's own space - session=db_session, - user=attacker, - ) - - assert exc_info.value.status_code == 404 - # Rejected at the search-space reconciliation, never reaching (or relying - # on) the permission check — which would have passed for space B. - check_permission_mock.assert_not_awaited() - - async def test_same_space_index_authorizes_against_the_connectors_own_space( - self, db_session: AsyncSession - ): - """A legitimate same-space index passes the reconciliation and authorizes - ``check_permission`` against the connector's **own** search space (not the - client-supplied query param).""" - owner, space = await _make_user_with_space(db_session) - # A "live" connector type returns early (no Celery dispatch) right after - # the permission check, so the call exercises the authz path cleanly. - connector = await _make_connector( - db_session, owner, space, SearchSourceConnectorType.CLICKUP_CONNECTOR - ) - - # Any downstream indexing behaviour is irrelevant to the authz contract - # under test; we only assert what space was authorized. - with ( - patch(_CHECK_PERMISSION, new=AsyncMock()) as check_permission_mock, - contextlib.suppress(Exception), - ): - await index_connector_content( - connector_id=connector.id, - search_space_id=space.id, # the connector's own space - session=db_session, - user=owner, - ) - - check_permission_mock.assert_awaited_once() - # The space passed to check_permission must be the connector's own space. - assert connector.search_space_id in check_permission_mock.await_args.args diff --git a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py deleted file mode 100644 index e987f8441..000000000 --- a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Regression tests for model-boundary message sanitization.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage - -from app.agents.chat.runtime.llm_config import _sanitize_messages - -pytestmark = pytest.mark.unit - - -def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None: - original = AIMessage( - content=[ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "visible answer"}, - ] - ) - - sanitized = _sanitize_messages([original]) - - assert sanitized[0].content == "visible answer" - assert original.content == [ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "visible answer"}, - ] - - -def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None: - message = AIMessage( - content="", - tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}], - ) - - sanitized = _sanitize_messages([message]) - - assert sanitized[0].content is None - assert message.content == "" diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py index f5709e517..79da12933 100644 --- a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -1,6 +1,6 @@ """Lock the runtime model-policy backstop in ``build_dependencies``. -Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so +Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so runs are insulated from later chat/search-space model changes), and the model policy is re-checked at run time so a captured model that is no longer billable fails the run clearly. When no snapshot is present, resolution falls back to the @@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch): return None -async def test_build_dependencies_resolves_captured_chat_model_id( +async def test_build_dependencies_resolves_captured_agent_llm_id( monkeypatch: pytest.MonkeyPatch, patched_side_effects ) -> None: - """The bundle loads with the *captured* ``chat_model_id``, not the live search space.""" + """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" captured: dict[str, Any] = {} async def _fake_load(_session, *, config_id, search_space_id): @@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_chat_model_id( lambda _ss: pytest.fail("search-space policy should not run on captured path"), ) - search_space = SimpleNamespace(chat_model_id=-99) + search_space = SimpleNamespace(agent_llm_id=-99) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) assert captured == {"config_id": -7, "search_space_id": 42} @@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids( monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) await build_dependencies( - session=_FakeSession(SimpleNamespace(chat_model_id=0)), + session=_FakeSession(SimpleNamespace(agent_llm_id=0)), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) assert seen == { - "chat_model_id": -7, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -7, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } @@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation( def _raise(**_kw): raise AutomationModelPolicyError( - [{"kind": "image", "model_id": -2, "reason": "free model"}] + [{"kind": "image", "config_id": -2, "reason": "free model"}] ) monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) @@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation( with pytest.raises(DependencyError): await build_dependencies( - session=_FakeSession(SimpleNamespace(chat_model_id=-7)), + session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=-2, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=-2, + vision_llm_config_id=-1, ) @@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space( lambda **_kw: pytest.fail("captured policy should not run on fallback path"), ) - search_space = SimpleNamespace(chat_model_id=-7) + search_space = SimpleNamespace(agent_llm_id=-7) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42 ) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py index c89624fbf..d7e3c4a0c 100644 --- a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -28,9 +28,9 @@ def _run() -> SimpleNamespace: def test_build_action_ctx_propagates_captured_models() -> None: """``definition.models`` flows onto the ActionContext model fields.""" models = AutomationModels( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, ) ctx = _build_action_ctx( cast(AsyncSession, None), @@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None: ) assert ctx.search_space_id == 42 - assert ctx.chat_model_id == -1 - assert ctx.image_gen_model_id == 5 - assert ctx.vision_model_id == -1 + assert ctx.agent_llm_id == -1 + assert ctx.image_generation_config_id == 5 + assert ctx.vision_llm_config_id == -1 def test_build_action_ctx_none_models_leaves_fields_none() -> None: @@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None: None, ) - assert ctx.chat_model_id is None - assert ctx.image_gen_model_id is None - assert ctx.vision_model_id is None + assert ctx.agent_llm_id is None + assert ctx.image_generation_config_id is None + assert ctx.vision_llm_config_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index dc7221b11..25e193ffa 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None: name="Daily digest", plan=[PlanStep(step_id="s1", action="agent_task")], models=AutomationModels( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, ), ) dumped = definition.model_dump(mode="json", by_alias=True) assert dumped["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } restored = AutomationDefinition.model_validate(dumped) assert restored.models is not None - assert restored.models.chat_model_id == -1 - assert restored.models.image_gen_model_id == 5 - assert restored.models.vision_model_id == -1 + assert restored.models.agent_llm_id == -1 + assert restored.models.image_generation_config_id == 5 + assert restored.models.vision_llm_config_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index c97dec6a2..0bbff39dc 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation( def _raise(_ss): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": 0, "reason": "Auto mode"}] + [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] ) monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) - service = _service(SimpleNamespace(chat_model_id=0)) + service = _service(SimpleNamespace(agent_llm_id=0)) with pytest.raises(HTTPException) as exc_info: await service._assert_models_billable(1) @@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok( automation_mod, "assert_automation_models_billable", lambda _ss: None ) - search_space = SimpleNamespace(chat_model_id=-1) + search_space = SimpleNamespace(agent_llm_id=-1) service = _service(search_space) assert await service._assert_models_billable(1) is search_space @@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, ) service = _service(search_space) payload = AutomationCreate( @@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space( automation = await service.create(payload) assert automation.definition["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } @@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - chat_model_id=None, - image_gen_model_id=None, - vision_model_id=None, + agent_llm_id=None, + image_generation_config_id=None, + vision_llm_config_id=None, ) service = _service(search_space) payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) @@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( automation = await service.create(payload) assert automation.definition["models"] == { - "chat_model_id": 0, - "image_gen_model_id": 0, - "vision_model_id": 0, + "agent_llm_id": 0, + "image_generation_config_id": 0, + "vision_llm_config_id": 0, } @@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided( ) validated: dict[str, Any] = {} - def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): validated["ids"] = ( - chat_model_id, - image_gen_model_id, - vision_model_id, + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) - service = _service(SimpleNamespace(chat_model_id=-99)) + service = _service(SimpleNamespace(agent_llm_id=-99)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - chat_model_id=-1, - image_gen_model_id=7, - vision_model_id=-2, + agent_llm_id=-1, + image_generation_config_id=7, + vision_llm_config_id=-2, ) ), ) @@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided( assert validated["ids"] == (-1, 7, -2) assert automation.definition["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 7, - "vision_model_id": -2, + "agent_llm_id": -1, + "image_generation_config_id": 7, + "vision_llm_config_id": -2, } @@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models( ) -> None: """A non-billable explicit selection maps the policy error to HTTP 422.""" - def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": -3, "reason": "free model"}] + [{"kind": "llm", "config_id": -3, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) - service = _service(SimpleNamespace(chat_model_id=-3)) + service = _service(SimpleNamespace(agent_llm_id=-3)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - chat_model_id=-3, - image_gen_model_id=7, - vision_model_id=-2, + agent_llm_id=-3, + image_generation_config_id=7, + vision_llm_config_id=-2, ) ), ) @@ -277,9 +277,9 @@ async def test_update_preserves_captured_models( ) -> None: """A definition edit carries over the previously captured ``models``.""" captured = { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid( "name": "A", "plan": [], "models": { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, }, }, version=3, ) validated: dict[str, Any] = {} - def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): validated["ids"] = ( - chat_model_id, - image_gen_model_id, - vision_model_id, + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - chat_model_id=-2, - image_gen_model_id=9, - vision_model_id=-2, + agent_llm_id=-2, + image_generation_config_id=9, + vision_llm_config_id=-2, ) ) ) @@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid( assert validated["ids"] == (-2, 9, -2) assert result.definition["models"] == { - "chat_model_id": -2, - "image_gen_model_id": 9, - "vision_model_id": -2, + "agent_llm_id": -2, + "image_generation_config_id": 9, + "vision_llm_config_id": -2, } assert result.version == 4 @@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models( "name": "A", "plan": [], "models": { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, }, }, version=3, ) - def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": -7, "reason": "free model"}] + [{"kind": "llm", "config_id": -7, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) ) ) @@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation( premium without an unrelated edit tripping the policy check. """ captured = { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload( lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, ) - service = _service(SimpleNamespace(chat_model_id=-2)) + service = _service(SimpleNamespace(agent_llm_id=-2)) result = await service.model_eligibility(search_space_id=5) assert result == {"allowed": False, "violations": [{"kind": "image"}]} diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 574f6d9fd..8e0806151 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit def _search_space(*, llm: int | None, image: int | None, vision: int | None): """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" return SimpleNamespace( - chat_model_id=llm, - image_gen_model_id=image, - vision_model_id=vision, + agent_llm_id=llm, + image_generation_config_id=image, + vision_llm_config_id=vision, ) @@ -39,11 +39,29 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. """ + llm_configs = { + -1: {"id": -1, "billing_tier": "premium"}, + -2: {"id": -2, "billing_tier": "free"}, + } + monkeypatch.setattr( + "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", + lambda cid: llm_configs.get(cid), + ) + from app.config import config as app_config monkeypatch.setattr( app_config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + monkeypatch.setattr( + app_config, + "GLOBAL_VISION_LLM_CONFIGS", [ {"id": -1, "billing_tier": "premium"}, {"id": -2, "billing_tier": "free"}, @@ -53,7 +71,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): return None -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: """A positive config id is a user-owned BYOK model — always billable.""" allowed, reason = model_policy._classify(kind, 7) @@ -61,7 +79,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) @pytest.mark.parametrize("config_id", [0, None]) def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: """Auto mode (id 0) and an unset slot (None) are blocked.""" @@ -70,7 +88,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: assert "Auto mode" in reason -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_premium_global_is_allowed(kind: str, patched_globals) -> None: """A negative (global) id with premium billing tier is allowed.""" allowed, reason = model_policy._classify(kind, -1) @@ -78,7 +96,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_free_global_is_blocked(kind: str, patched_globals) -> None: """A negative (global) id with a free billing tier is blocked.""" allowed, reason = model_policy._classify(kind, -2) @@ -86,7 +104,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None: assert "free model" in reason -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: """A negative id that resolves to no config is treated as not premium.""" allowed, _ = model_policy._classify(kind, -999) @@ -107,10 +125,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None: assert result["allowed"] is False kinds = {v["kind"] for v in result["violations"]} - assert kinds == {"chat", "image", "vision"} - # model_id is echoed back for the UI / settings deep-link. - by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} - assert by_kind == {"chat": -2, "image": 0, "vision": -2} + assert kinds == {"llm", "image", "vision"} + # config_id is echoed back for the UI / settings deep-link. + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} def test_assert_raises_with_violations(patched_globals) -> None: @@ -120,7 +138,7 @@ def test_assert_raises_with_violations(patched_globals) -> None: assert_automation_models_billable(search_space) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "chat" + assert exc_info.value.violations[0]["kind"] == "llm" def test_assert_passes_when_all_billable(patched_globals) -> None: @@ -135,7 +153,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None: def test_get_model_eligibility_all_billable(patched_globals) -> None: """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" result = get_model_eligibility( - chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1 + agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 ) assert result == {"allowed": True, "violations": []} @@ -143,28 +161,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None: def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" result = get_model_eligibility( - chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 ) assert result["allowed"] is False - by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} - assert by_kind == {"chat": -2, "image": 0, "vision": -2} + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} def test_assert_models_billable_raises(patched_globals) -> None: """``assert_models_billable`` raises when any explicit id is blocked.""" with pytest.raises(AutomationModelPolicyError) as exc_info: assert_models_billable( - chat_model_id=0, image_gen_model_id=5, vision_model_id=-1 + agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 ) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "chat" + assert exc_info.value.violations[0]["kind"] == "llm" def test_assert_models_billable_passes(patched_globals) -> None: """No exception when every explicit id is premium or BYOK.""" assert ( assert_models_billable( - chat_model_id=3, image_gen_model_id=-1, vision_model_id=4 + agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 ) is None ) @@ -174,5 +192,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: """The search-space wrapper produces the same result as the ID core.""" search_space = _search_space(llm=-2, image=0, vision=-2) assert get_automation_model_eligibility(search_space) == get_model_eligibility( - chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 ) diff --git a/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py index a74591169..b87d1be42 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py @@ -272,26 +272,22 @@ def full_scan_mocks(mock_dropbox_client, monkeypatch): download_and_index_mock = AsyncMock(return_value=(0, 0)) monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) - from app.services.etl_credit_service import EtlCreditService as _RealECS + from app.services.page_limit_service import PageLimitService as _RealPLS - # get_available_micros -> None means "unlimited" (billing disabled), so no - # batch is gated and charge_credits is a no-op — matching the prior - # 999_999 page-limit intent for these parallel-processing tests. - mock_credit_instance = MagicMock() - mock_credit_instance.get_available_micros = AsyncMock(return_value=None) - mock_credit_instance.charge_credits = AsyncMock(return_value=None) + mock_page_limit_instance = MagicMock() + mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999)) + mock_page_limit_instance.update_page_usage = AsyncMock() - class _MockEtlCreditService: + class _MockPageLimitService: estimate_pages_from_metadata = staticmethod( - _RealECS.estimate_pages_from_metadata + _RealPLS.estimate_pages_from_metadata ) - pages_to_micros = staticmethod(_RealECS.pages_to_micros) def __init__(self, session): - self.get_available_micros = mock_credit_instance.get_available_micros - self.charge_credits = mock_credit_instance.charge_credits + self.get_page_usage = mock_page_limit_instance.get_page_usage + self.update_page_usage = mock_page_limit_instance.update_page_usage - monkeypatch.setattr(_mod, "EtlCreditService", _MockEtlCreditService) + monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService) return { "dropbox_client": mock_dropbox_client, @@ -397,23 +393,22 @@ def selected_files_mocks(mock_dropbox_client, monkeypatch): download_and_index_mock = AsyncMock(return_value=(0, 0)) monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) - from app.services.etl_credit_service import EtlCreditService as _RealECS + from app.services.page_limit_service import PageLimitService as _RealPLS - mock_credit_instance = MagicMock() - mock_credit_instance.get_available_micros = AsyncMock(return_value=None) - mock_credit_instance.charge_credits = AsyncMock(return_value=None) + mock_page_limit_instance = MagicMock() + mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999)) + mock_page_limit_instance.update_page_usage = AsyncMock() - class _MockEtlCreditService: + class _MockPageLimitService: estimate_pages_from_metadata = staticmethod( - _RealECS.estimate_pages_from_metadata + _RealPLS.estimate_pages_from_metadata ) - pages_to_micros = staticmethod(_RealECS.pages_to_micros) def __init__(self, session): - self.get_available_micros = mock_credit_instance.get_available_micros - self.charge_credits = mock_credit_instance.charge_credits + self.get_page_usage = mock_page_limit_instance.get_page_usage + self.update_page_usage = mock_page_limit_instance.update_page_usage - monkeypatch.setattr(_mod, "EtlCreditService", _MockEtlCreditService) + monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService) return { "dropbox_client": mock_dropbox_client, diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 4f61976a6..9a13e4525 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -242,28 +242,20 @@ def _folder_dict(file_id: str, name: str) -> dict: } -def _make_page_limit_session(balance_micros=999_999_000, reserved_micros=0): - """Build a mock DB session that real EtlCreditService can operate against. - - ETL credit billing is disabled by default in tests, so get_available_micros - short-circuits to None ("unlimited") and these fields are unused; they're - provided for parity if a test opts into billing. - """ +def _make_page_limit_session(pages_used=0, pages_limit=999_999): + """Build a mock DB session that real PageLimitService can operate against.""" class _FakeUser: - def __init__(self, balance, reserved): - self.credit_micros_balance = balance - self.credit_micros_reserved = reserved + def __init__(self, pu, pl): + self.pages_used = pu + self.pages_limit = pl - fake_user = _FakeUser(balance_micros, reserved_micros) + fake_user = _FakeUser(pages_used, pages_limit) session = AsyncMock() def _make_result(*_a, **_kw): r = MagicMock() - r.first.return_value = ( - fake_user.credit_micros_balance, - fake_user.credit_micros_reserved, - ) + r.first.return_value = (fake_user.pages_used, fake_user.pages_limit) r.unique.return_value.scalar_one_or_none.return_value = fake_user return r diff --git a/surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py similarity index 73% rename from surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py rename to surfsense_backend/tests/unit/connector_indexers/test_page_limits.py index aca811ee9..66722ffd7 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py @@ -1,22 +1,17 @@ -"""Tests for ETL credit enforcement in connector indexers. +"""Tests for page limit enforcement in connector indexers. Covers: - A) EtlCreditService.estimate_pages_from_metadata — pure function (no mocks) - B) Credit-wallet gating in the connector indexers, tested through the real - EtlCreditService with a mock DB session (system boundary). ETL credit - billing is force-enabled per-test so the gating path is exercised. + A) PageLimitService.estimate_pages_from_metadata — pure function (no mocks) + B) Page-limit quota gating in _index_selected_files tested through the + real PageLimitService with a mock DB session (system boundary). Google Drive is the primary, with OneDrive/Dropbox smoke tests. - -Page estimates are converted to micro-USD at ``config.MICROS_PER_PAGE`` per -page and debited from ``user.credit_micros_balance``. """ from unittest.mock import AsyncMock, MagicMock import pytest -from app.config import config -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService pytestmark = pytest.mark.unit @@ -25,23 +20,8 @@ _CONNECTOR_ID = 42 _SEARCH_SPACE_ID = 1 -def _micros(pages: int) -> int: - """Convert a page count to micro-USD using the configured rate.""" - return pages * config.MICROS_PER_PAGE - - -@pytest.fixture(autouse=True) -def _enable_etl_billing(monkeypatch): - """Force ETL credit billing on so the gating/charging path runs. - - It defaults to off (self-hosted/OSS), which would short-circuit - get_available_micros to None and bypass every check in this module. - """ - monkeypatch.setattr(config, "ETL_CREDIT_BILLING_ENABLED", True) - - # =================================================================== -# A) EtlCreditService.estimate_pages_from_metadata — pure function +# A) PageLimitService.estimate_pages_from_metadata — pure function # No mocks: it's a staticmethod with no I/O. # =================================================================== @@ -50,91 +30,88 @@ class TestEstimatePagesFromMetadata: """Vertical slices for the page estimation staticmethod.""" def test_pdf_100kb_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1 def test_pdf_500kb_returns_5(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5 def test_pdf_1mb(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10 def test_docx_50kb_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1 + assert PageLimitService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1 def test_docx_200kb(self): - assert EtlCreditService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4 + assert PageLimitService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4 def test_pptx_uses_200kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3 + assert PageLimitService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3 def test_xlsx_uses_100kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3 + assert PageLimitService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3 def test_txt_uses_3000_bytes_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".txt", 9000) == 3 + assert PageLimitService.estimate_pages_from_metadata(".txt", 9000) == 3 def test_image_always_returns_1(self): for ext in (".jpg", ".png", ".gif", ".webp"): - assert EtlCreditService.estimate_pages_from_metadata(ext, 5_000_000) == 1 + assert PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1 def test_audio_uses_1mb_per_page(self): assert ( - EtlCreditService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 + PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 ) def test_video_uses_5mb_per_page(self): assert ( - EtlCreditService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 + PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 ) def test_unknown_ext_uses_80kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 + assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 def test_zero_size_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 0) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 0) == 1 def test_negative_size_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", -500) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", -500) == 1 def test_minimum_is_always_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 50) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 50) == 1 def test_epub_uses_50kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5 + assert PageLimitService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5 # =================================================================== -# B) Credit enforcement in connector indexers -# System boundary mocked: DB session (for EtlCreditService) +# B) Page-limit enforcement in connector indexers +# System boundary mocked: DB session (for PageLimitService) # System boundary mocked: external API clients, download/ETL -# NOT mocked: EtlCreditService itself (our own code) +# NOT mocked: PageLimitService itself (our own code) # =================================================================== class _FakeUser: """Stands in for the User ORM model at the DB boundary.""" - def __init__(self, balance_micros: int = 0, reserved_micros: int = 0): - self.credit_micros_balance = balance_micros - self.credit_micros_reserved = reserved_micros + def __init__(self, pages_used: int = 0, pages_limit: int = 100): + self.pages_used = pages_used + self.pages_limit = pages_limit -def _make_credit_session(balance_micros: int = _micros(100), reserved_micros: int = 0): - """Build a mock DB session that the real EtlCreditService can operate against. +def _make_page_limit_session(pages_used: int = 0, pages_limit: int = 100): + """Build a mock DB session that real PageLimitService can operate against. Every ``session.execute()`` returns a result compatible with both - ``get_available_micros`` (.first() → ``(balance, reserved)``) and - ``charge_credits`` (.unique().scalar_one_or_none() → User-like). + ``get_page_usage`` (.first() → tuple) and ``update_page_usage`` + (.unique().scalar_one_or_none() → User-like). """ - fake_user = _FakeUser(balance_micros, reserved_micros) + fake_user = _FakeUser(pages_used, pages_limit) session = AsyncMock() def _make_result(*_args, **_kwargs): result = MagicMock() - result.first.return_value = ( - fake_user.credit_micros_balance, - fake_user.credit_micros_reserved, - ) + result.first.return_value = (fake_user.pages_used, fake_user.pages_limit) result.unique.return_value.scalar_one_or_none.return_value = fake_user return result @@ -161,7 +138,7 @@ def gdrive_selected_mocks(monkeypatch): """Mocks for Google Drive _index_selected_files — only system boundaries.""" import app.tasks.connector_indexers.google_drive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -206,11 +183,12 @@ async def _run_gdrive_selected(mocks, file_ids): ) -async def test_gdrive_files_within_credit_are_downloaded(gdrive_selected_mocks): - """Files whose cumulative estimated cost fits within available credit +async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks): + """Files whose cumulative estimated pages fit within remaining quota are sent to _download_and_index.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( @@ -229,10 +207,11 @@ async def test_gdrive_files_within_credit_are_downloaded(gdrive_selected_mocks): assert len(call_files) == 3 -async def test_gdrive_files_exceeding_credit_rejected(gdrive_selected_mocks): - """Files whose cost would exceed available credit are rejected.""" +async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks): + """Files whose pages would exceed remaining quota are rejected.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 98 + m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( _make_gdrive_file("big", "huge.pdf", size=500 * 1024), @@ -245,13 +224,14 @@ async def test_gdrive_files_exceeding_credit_rejected(gdrive_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() -async def test_gdrive_credit_mix_partial_indexing(gdrive_selected_mocks): - """3rd file pushes over available credit → only first two indexed.""" +async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks): + """3rd file pushes over quota → only first two indexed.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( @@ -270,10 +250,11 @@ async def test_gdrive_credit_mix_partial_indexing(gdrive_selected_mocks): assert {f["id"] for f in call_files} == {"f1", "f2"} -async def test_gdrive_proportional_credit_deduction(gdrive_selected_mocks): - """Credit deducted is proportional to successfully indexed files.""" +async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks): + """Pages deducted are proportional to successfully indexed files.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2", "f3", "f4"): m["get_file_results"][fid] = ( @@ -287,14 +268,14 @@ async def test_gdrive_proportional_credit_deduction(gdrive_selected_mocks): [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz"), ("f4", "f4.xyz")], ) - # 4 estimated pages, 2 of 4 indexed → deduct 2 pages. - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): - """If batch_indexed == 0, the user's balance stays unchanged.""" + """If batch_indexed == 0, user's pages_used stays unchanged.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(95) + m["fake_user"].pages_used = 5 + m["fake_user"].pages_limit = 100 m["get_file_results"]["f1"] = ( _make_gdrive_file("f1", "f1.xyz", size=80 * 1024), @@ -304,13 +285,14 @@ async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): await _run_gdrive_selected(m, [("f1", "f1.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(95) + assert m["fake_user"].pages_used == 5 -async def test_gdrive_zero_credit_rejects_all(gdrive_selected_mocks): - """When the balance is exhausted, every file is rejected.""" +async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks): + """When pages_used == pages_limit, every file is rejected.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = 0 + m["fake_user"].pages_used = 100 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2"): m["get_file_results"][fid] = ( @@ -335,7 +317,7 @@ async def test_gdrive_zero_credit_rejects_all(gdrive_selected_mocks): def gdrive_full_scan_mocks(monkeypatch): import app.tasks.connector_indexers.google_drive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) mock_task_logger = MagicMock() mock_task_logger.log_task_progress = AsyncMock() @@ -382,9 +364,10 @@ async def _run_gdrive_full_scan(mocks, max_files=500): ) -async def test_gdrive_full_scan_skips_over_credit(gdrive_full_scan_mocks, monkeypatch): +async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeypatch): m = gdrive_full_scan_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 page_files = [ _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5) @@ -408,7 +391,8 @@ async def test_gdrive_full_scan_deducts_after_indexing( gdrive_full_scan_mocks, monkeypatch ): m = gdrive_full_scan_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 page_files = [ _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3) @@ -424,7 +408,7 @@ async def test_gdrive_full_scan_deducts_after_indexing( await _run_gdrive_full_scan(m) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(3) + assert m["fake_user"].pages_used == 3 # --------------------------------------------------------------------------- @@ -432,10 +416,10 @@ async def test_gdrive_full_scan_deducts_after_indexing( # --------------------------------------------------------------------------- -async def test_gdrive_delta_sync_skips_over_credit(monkeypatch): +async def test_gdrive_delta_sync_skips_over_quota(monkeypatch): import app.tasks.connector_indexers.google_drive_indexer as _mod - session, _ = _make_credit_session(_micros(2)) + session, _ = _make_page_limit_session(0, 2) changes = [ { @@ -487,7 +471,7 @@ async def test_gdrive_delta_sync_skips_over_credit(monkeypatch): # =================================================================== -# C) OneDrive smoke tests — verify credit wiring +# C) OneDrive smoke tests — verify page limit wiring # =================================================================== @@ -505,7 +489,7 @@ def _make_onedrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict: def onedrive_selected_mocks(monkeypatch): import app.tasks.connector_indexers.onedrive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -547,10 +531,11 @@ async def _run_onedrive_selected(mocks, file_ids): ) -async def test_onedrive_over_credit_rejected(onedrive_selected_mocks): - """OneDrive: files exceeding available credit produce errors, not downloads.""" +async def test_onedrive_over_quota_rejected(onedrive_selected_mocks): + """OneDrive: files exceeding quota produce errors, not downloads.""" m = onedrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(1) + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( _make_onedrive_file("big", "huge.pdf", size=500 * 1024), @@ -563,13 +548,14 @@ async def test_onedrive_over_credit_rejected(onedrive_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() async def test_onedrive_deducts_after_success(onedrive_selected_mocks): - """OneDrive: balance decreases after successful indexing.""" + """OneDrive: pages_used increases after successful indexing.""" m = onedrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2"): m["get_file_results"][fid] = ( @@ -580,11 +566,11 @@ async def test_onedrive_deducts_after_success(onedrive_selected_mocks): await _run_onedrive_selected(m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 # =================================================================== -# D) Dropbox smoke tests — verify credit wiring +# D) Dropbox smoke tests — verify page limit wiring # =================================================================== @@ -604,7 +590,7 @@ def _make_dropbox_file(file_path: str, name: str, size: int = 80 * 1024) -> dict def dropbox_selected_mocks(monkeypatch): import app.tasks.connector_indexers.dropbox_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -646,10 +632,11 @@ async def _run_dropbox_selected(mocks, file_paths): ) -async def test_dropbox_over_credit_rejected(dropbox_selected_mocks): - """Dropbox: files exceeding available credit produce errors, not downloads.""" +async def test_dropbox_over_quota_rejected(dropbox_selected_mocks): + """Dropbox: files exceeding quota produce errors, not downloads.""" m = dropbox_selected_mocks - m["fake_user"].credit_micros_balance = _micros(1) + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 m["get_file_results"]["/huge.pdf"] = ( _make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024), @@ -662,13 +649,14 @@ async def test_dropbox_over_credit_rejected(dropbox_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() async def test_dropbox_deducts_after_success(dropbox_selected_mocks): - """Dropbox: balance decreases after successful indexing.""" + """Dropbox: pages_used increases after successful indexing.""" m = dropbox_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for name in ("f1.xyz", "f2.xyz"): path = f"/{name}" @@ -680,4 +668,4 @@ async def test_dropbox_deducts_after_success(dropbox_selected_mocks): await _run_dropbox_selected(m, [("/f1.xyz", "f1.xyz"), ("/f2.xyz", "f2.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py deleted file mode 100644 index c6efddc09..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Stub the cache package __init__s so unit tests import only pure leaf modules. - -The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly -import the facade, file storage, Celery, and ``app.db`` -- none of which a pure -unit test should need. Turning those packages into bare namespace packages lets -``from app.etl_pipeline.cache.<pkg>.<leaf> import ...`` resolve the leaf module -without running the heavy __init__. ``schemas`` is left real (it is pure). -""" - -import sys -import types -from pathlib import Path - -_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "etl_pipeline" / "cache" - - -def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: - if dotted in sys.modules: - return - module = types.ModuleType(dotted) - module.__path__ = [str(fs_dir)] - module.__package__ = dotted - sys.modules[dotted] = module - - -_stub_namespace_package("app.etl_pipeline.cache", _CACHE_DIR) -_stub_namespace_package("app.etl_pipeline.cache.storage", _CACHE_DIR / "storage") -_stub_namespace_package("app.etl_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py deleted file mode 100644 index 99d8e67b6..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py +++ /dev/null @@ -1,88 +0,0 @@ -"""What is allowed into the cache -- the gating rules, as pure logic. - -These rules decide whether a given upload may be served from / written to the -parse cache. They live in a pure predicate so every branch (disabled, vision, -no service, file category) is covered here without touching DB, storage, or the -parser. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.eligibility import is_parse_cacheable - -pytestmark = pytest.mark.unit - - -def test_document_with_service_and_cache_on_is_cacheable(): - assert is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) - - -def test_disabled_cache_is_never_cacheable(): - assert not is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=False, - has_vision_llm=False, - ) - - -def test_vision_llm_run_is_not_cacheable(): - # Vision appends model output not captured by the key; sharing it would leak - # one run's generated text into a plain parse of the same bytes. - assert not is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=True, - ) - - -@pytest.mark.parametrize("etl_service", [None, ""]) -def test_missing_etl_service_is_not_cacheable(etl_service): - assert not is_parse_cacheable( - filename="report.pdf", - etl_service=etl_service, - cache_enabled=True, - has_vision_llm=False, - ) - - -@pytest.mark.parametrize( - "filename", - ["paper.pdf", "memo.docx", "slides.pptx", "sheet.xlsx", "book.epub"], -) -def test_document_extensions_are_cacheable(filename): - assert is_parse_cacheable( - filename=filename, - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) - - -@pytest.mark.parametrize( - "filename", - [ - "notes.txt", # plaintext - "readme.md", # plaintext - "main.py", # plaintext - "podcast.mp3", # audio - "photo.png", # image (vision path / fallback, not a shared doc parse) - "data.csv", # direct-convert - "archive.xyz", # unsupported - ], -) -def test_non_document_categories_are_not_cacheable(filename): - assert not is_parse_cacheable( - filename=filename, - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py deleted file mode 100644 index 5113d7c42..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Size-based eviction: drop just enough of the coldest entries to fit budget. - -The caller supplies candidates already ordered coldest-first; this pure rule only -decides how far down that list to cut. It must never over-evict (stop as soon as -the footprint fits) and never promise more than the candidates can free. -""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest - -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.schemas import EvictionCandidate - -pytestmark = pytest.mark.unit - - -def _candidate(id_: int, size_bytes: int) -> EvictionCandidate: - return EvictionCandidate( - id=id_, - storage_key=f"etl_cache/{id_}.md", - size_bytes=size_bytes, - last_used_at=datetime(2026, 1, 1, tzinfo=UTC), - times_reused=0, - ) - - -def test_over_budget_drops_coldest_until_it_fits(): - # 300 used, budget 100 -> must free >=200. Coldest-first [120, 90, 70]; - # 120+90=210 >=200, so the third (70) is spared. - coldest_first = [_candidate(1, 120), _candidate(2, 90), _candidate(3, 70)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=300, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1, 2] - - -@pytest.mark.parametrize("current_total_bytes", [100, 80]) -def test_within_budget_evicts_nothing(current_total_bytes): - # At or under budget there is nothing to free, so no blob is touched. - coldest_first = [_candidate(1, 50), _candidate(2, 50)] - - chosen = select_over_budget( - coldest_first, - current_total_bytes=current_total_bytes, - max_total_bytes=100, - ) - - assert chosen == [] - - -def test_stops_as_soon_as_one_entry_covers_the_overage(): - # Only 10 over budget; the first (cold) entry already frees enough. - coldest_first = [_candidate(1, 40), _candidate(2, 40)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=110, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1] - - -def test_returns_all_candidates_when_they_cannot_free_enough(): - # Deficit is 500 but candidates only total 150: return everything available - # rather than looping forever or raising. - coldest_first = [_candidate(1, 100), _candidate(2, 50)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=600, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1, 2] diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py deleted file mode 100644 index d69e74ee0..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Content-addressing: equal (bytes + recipe) must map to one storage location. - -This is the dedup guarantee the whole cache rests on -- two users uploading the -same file under the same parser settings have to land on the same object key, and -any change to bytes or recipe has to land somewhere else. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage.object_keys import ( - CACHE_PREFIX, - build_parse_object_key, -) - -pytestmark = pytest.mark.unit - - -def _key(**overrides) -> ParseKey: - base = { - "source_sha256": "a" * 64, - "etl_service": "LLAMACLOUD", - "mode": "basic", - "version": 1, - } - base.update(overrides) - return ParseKey.for_document( - base["source_sha256"], - etl_service=base["etl_service"], - mode=base["mode"], - version=base["version"], - ) - - -def test_same_bytes_and_recipe_produce_the_same_object_key(): - assert build_parse_object_key(_key()) == build_parse_object_key(_key()) - - -def test_different_bytes_produce_different_object_keys(): - assert build_parse_object_key( - _key(source_sha256="a" * 64) - ) != build_parse_object_key(_key(source_sha256="b" * 64)) - - -@pytest.mark.parametrize( - "field, value", - [ - ("etl_service", "DOCLING"), - ("mode", "premium"), - ("version", 2), - ], -) -def test_any_recipe_change_produces_a_different_object_key(field, value): - # Same bytes but a different parser/mode/version must not collide: the recipe - # is part of the identity, so changing it has to re-parse, not reuse. - assert build_parse_object_key(_key()) != build_parse_object_key( - _key(**{field: value}) - ) - - -def test_object_key_is_prefixed_and_sharded_by_source_hash(): - # Shape matters operationally: a dedicated top-level prefix keeps cache blobs - # out of the normal store, and the sha directory groups every recipe variant - # of one file together. - key = _key() - assert build_parse_object_key(key) == ( - f"{CACHE_PREFIX}/{key.source_sha256}/LLAMACLOUD.basic.v1.md" - ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py deleted file mode 100644 index 081dddaa7..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Stub the cache package __init__s so unit tests import only pure leaf modules. - -The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly -import the facade, file storage, Celery, and ``app.db`` -- none of which a pure -unit test should need. Turning those packages into bare namespace packages lets -``from app.indexing_pipeline.cache.<leaf> import ...`` resolve the leaf module -without running the heavy __init__. ``schemas`` is left real (it is pure). -""" - -import sys -import types -from pathlib import Path - -_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "indexing_pipeline" / "cache" - - -def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: - if dotted in sys.modules: - return - module = types.ModuleType(dotted) - module.__path__ = [str(fs_dir)] - module.__package__ = dotted - sys.modules[dotted] = module - - -_stub_namespace_package("app.indexing_pipeline.cache", _CACHE_DIR) -_stub_namespace_package("app.indexing_pipeline.cache.storage", _CACHE_DIR / "storage") -_stub_namespace_package("app.indexing_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py deleted file mode 100644 index 2e488231c..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py +++ /dev/null @@ -1,28 +0,0 @@ -from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable - - -def test_disabled_cache_is_never_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=False, embedding_model="m", embedding_dim=384 - ) - - -def test_missing_model_is_not_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model=None, embedding_dim=384 - ) - - -def test_missing_dimension_is_not_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=None - ) - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=0 - ) - - -def test_enabled_with_model_and_dim_is_cacheable(): - assert is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=384 - ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py deleted file mode 100644 index ce9c8672d..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py +++ /dev/null @@ -1,31 +0,0 @@ -from app.indexing_pipeline.cache.schemas import EmbeddingKey - - -def _key(**overrides) -> EmbeddingKey: - base = { - "markdown_sha256": "a" * 64, - "embedding_model": "openai://text-embedding-3-small", - "embedding_dim": 1536, - "chunker_kind": "hybrid", - "chunker_version": 1, - } - base.update(overrides) - return EmbeddingKey(**base) - - -def test_object_suffix_is_stable(): - assert _key().object_suffix == _key().object_suffix - - -def test_object_suffix_differs_by_model(): - assert _key().object_suffix != _key(embedding_model="local/minilm").object_suffix - - -def test_object_suffix_differs_by_chunker_kind_and_version(): - assert _key().object_suffix != _key(chunker_kind="code").object_suffix - assert _key().object_suffix != _key(chunker_version=2).object_suffix - - -def test_object_suffix_encodes_kind_and_version(): - suffix = _key(chunker_kind="code", chunker_version=3).object_suffix - assert suffix.endswith(".code.v3.emb") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py deleted file mode 100644 index f8cff6355..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet -from app.indexing_pipeline.cache.serialization import deserialize, serialize - - -def _make_set(dim: int, n_chunks: int) -> EmbeddingSet: - rng = np.random.default_rng(0) - return EmbeddingSet( - summary_embedding=rng.random(dim, dtype=np.float64), - chunks=[ - CachedChunk(text=f"chunk {i}\nwith newline", embedding=rng.random(dim)) - for i in range(n_chunks) - ], - ) - - -def test_round_trip_preserves_texts_and_vectors(): - original = _make_set(dim=8, n_chunks=3) - - restored = deserialize(serialize(original)) - - assert [c.text for c in restored.chunks] == [c.text for c in original.chunks] - assert restored.chunk_count == 3 - assert np.allclose( - restored.summary_embedding, original.summary_embedding, atol=1e-6 - ) - for got, want in zip(restored.chunks, original.chunks, strict=True): - assert np.allclose(got.embedding, want.embedding, atol=1e-6) - - -def test_round_trip_with_no_chunks(): - original = _make_set(dim=4, n_chunks=0) - - restored = deserialize(serialize(original)) - - assert restored.chunk_count == 0 - assert restored.summary_embedding.shape[0] == 4 - - -def test_serialize_rejects_mismatched_dimensions(): - bad = EmbeddingSet( - summary_embedding=np.zeros(4, dtype=np.float32), - chunks=[CachedChunk(text="x", embedding=np.zeros(8, dtype=np.float32))], - ) - - with pytest.raises(ValueError): - serialize(bad) - - -def test_deserialize_rejects_foreign_blob(): - with pytest.raises(ValueError): - deserialize(b"not-a-surfsense-blob") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py b/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py deleted file mode 100644 index 7effce840..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py +++ /dev/null @@ -1,94 +0,0 @@ -"""reconcile(): diff existing chunk rows against new chunk texts. - -The reconciler decides which rows (and embeddings) survive an edit, which texts -must be embedded, and which rows go away -- purely from content, no DB. -""" - -from __future__ import annotations - -from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile - - -def _existing(*contents: str) -> list[ExistingChunk]: - return [ - ExistingChunk(id=i + 1, content=text, position=i) - for i, text in enumerate(contents) - ] - - -def test_identical_content_keeps_every_row_untouched(): - plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "beta", "gamma"]) - - assert plan.to_embed == [] - assert plan.to_delete == [] - assert plan.reused == [] - - -def test_head_insert_embeds_only_the_new_chunk_and_shifts_the_rest(): - plan = reconcile(_existing("alpha", "beta"), ["intro", "alpha", "beta"]) - - assert plan.to_embed == [(0, "intro")] - assert plan.to_delete == [] - # alpha: position 0 -> 1, beta: 1 -> 2; embeddings untouched. - assert plan.reused == [(1, 1), (2, 2)] - - -def test_middle_edit_swaps_exactly_one_chunk(): - plan = reconcile( - _existing("alpha", "beta", "gamma"), ["alpha", "beta EDITED", "gamma"] - ) - - assert plan.to_embed == [(1, "beta EDITED")] - assert plan.to_delete == [2] - # Neighbours did not move, so no position writes at all. - assert plan.reused == [] - - -def test_removed_chunk_is_deleted_and_followers_shift_up(): - plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "gamma"]) - - assert plan.to_embed == [] - assert plan.to_delete == [2] - assert plan.reused == [(3, 1)] - - -def test_duplicate_texts_pair_up_one_to_one(): - # Two identical boilerplate chunks, only one survives the edit: exactly one - # row is kept and exactly one is deleted -- never both kept or both dropped. - plan = reconcile(_existing("boiler", "boiler", "body"), ["boiler", "body"]) - - assert plan.to_embed == [] - assert plan.to_delete == [2] - assert plan.reused == [(3, 1)] - - -def test_duplicate_growth_embeds_only_the_extra_copy(): - plan = reconcile(_existing("boiler", "body"), ["boiler", "boiler", "body"]) - - assert plan.to_embed == [(1, "boiler")] - assert plan.to_delete == [] - assert plan.reused == [(2, 2)] - - -def test_reorder_becomes_position_updates_with_no_embedding(): - plan = reconcile(_existing("alpha", "beta"), ["beta", "alpha"]) - - assert plan.to_embed == [] - assert plan.to_delete == [] - assert sorted(plan.reused) == [(1, 1), (2, 0)] - - -def test_full_rewrite_replaces_everything(): - plan = reconcile(_existing("alpha", "beta"), ["new one", "new two"]) - - assert plan.to_embed == [(0, "new one"), (1, "new two")] - assert sorted(plan.to_delete) == [1, 2] - assert plan.reused == [] - - -def test_no_existing_chunks_embeds_all(): - plan = reconcile([], ["alpha", "beta"]) - - assert plan.to_embed == [(0, "alpha"), (1, "beta")] - assert plan.to_delete == [] - assert plan.reused == [] diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py index feb7bbc52..3a1b77d90 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -54,7 +54,7 @@ async def test_index_calls_embed_and_chunk_via_to_thread( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock_chunk_hybrid, ) mock_embed = MagicMock( @@ -62,21 +62,13 @@ async def test_index_calls_embed_and_chunk_via_to_thread( ) mock_embed.__name__ = "embed_texts" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock_embed, ) + # Bypass set_committed_value, which requires a real ORM instance (not MagicMock). monkeypatch.setattr( - pipeline, - "_load_existing_chunks", - AsyncMock(return_value=[]), - ) - - async def _noop_persist(_session, doc, *_args, **_kwargs): - doc.status = DocumentStatus.ready() - - monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index", - _noop_persist, + "app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document", + MagicMock(), ) connector_doc = make_connector_document( @@ -110,31 +102,22 @@ async def test_non_code_documents_use_hybrid_chunker( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock_chunk_hybrid, ) mock_chunk_code = MagicMock(return_value=["chunk1"]) mock_chunk_code.__name__ = "chunk_text" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", mock_chunk_code, ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( - pipeline, - "_load_existing_chunks", - AsyncMock(return_value=[]), - ) - - async def _noop_persist(_session, doc, *_args, **_kwargs): - doc.status = DocumentStatus.ready() - - monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index", - _noop_persist, + "app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document", + MagicMock(), ) connector_doc = make_connector_document( diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py b/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py deleted file mode 100644 index 026c3161d..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py +++ /dev/null @@ -1,65 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from app.db import Chunk, Document, DocumentStatus -from app.indexing_pipeline.document_persistence import persist_scratch_index - -pytestmark = pytest.mark.unit - - -def _make_document(doc_id: int = 1) -> Document: - document = MagicMock(spec=Document) - document.id = doc_id - document.content = None - document.status = DocumentStatus.processing() - return document - - -@pytest.mark.asyncio -async def test_persist_scratch_index_batches_commits(monkeypatch): - monkeypatch.setattr( - "app.indexing_pipeline.document_persistence.set_committed_value", - lambda *_args, **_kwargs: None, - ) - session = MagicMock() - session.commit = AsyncMock() - document = _make_document() - chunks = [Chunk(content=f"c{i}", embedding=[0.1], position=i) for i in range(5)] - perf = MagicMock() - - await persist_scratch_index( - session, - document, - "body", - chunks, - batch_size=2, - perf=perf, - ) - - assert session.commit.await_count == 5 - assert document.status == DocumentStatus.ready() - - -@pytest.mark.asyncio -async def test_persist_scratch_index_empty_chunks(monkeypatch): - monkeypatch.setattr( - "app.indexing_pipeline.document_persistence.set_committed_value", - lambda *_args, **_kwargs: None, - ) - session = MagicMock() - session.commit = AsyncMock() - document = _make_document() - perf = MagicMock() - - await persist_scratch_index( - session, - document, - "body", - [], - batch_size=200, - perf=perf, - ) - - assert session.commit.await_count == 2 - assert document.status == DocumentStatus.ready() diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py index 9fc93d3ed..2f0a6a9d3 100644 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py @@ -61,21 +61,3 @@ def test_completion_failure(): assert message == "Processing failed: bad" assert status == "failed" assert meta["processing_stage"] == "failed" - - -def test_started_title_truncates_long_name(): - """Very long filenames are truncated to fit the notification title column.""" - long_name = "a" * 250 - title = msg.started_title(long_name) - assert len(title) <= 200 - assert title.startswith("Processing: ") - assert title.endswith("...") - - -def test_completion_truncates_long_name(): - """Completion titles truncate long document names.""" - long_name = "b" * 250 - title, _, _, _ = msg.completion(long_name, document_id=1) - assert len(title) <= 200 - assert title.startswith("Ready: ") - assert title.endswith("...") diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py b/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py deleted file mode 100644 index c5366cce2..000000000 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Unit tests for insufficient-credits presentation logic.""" - -from __future__ import annotations - -import pytest - -from app.notifications.service.messages import insufficient_credits as msg - -pytestmark = pytest.mark.unit - - -def test_operation_id_encodes_search_space(): - """The operation id embeds the search space id.""" - assert msg.operation_id("doc.pdf", 9).startswith("insufficient_credits_9_") - - -def test_summary_title_and_message(): - """The summary states the document and the required/available credit.""" - title, message = msg.summary( - "short.pdf", balance_micros=250_000, required_micros=1_000_000 - ) - assert title == "Insufficient credits: short.pdf" - assert message == ( - "This document costs about $1.00 to process but you have " - "$0.25 of credit left. Add more credits to continue." - ) - - -def test_summary_clamps_negative_balance_to_zero(): - """A negative balance is clamped to $0.00 in the message.""" - _, message = msg.summary("doc.pdf", balance_micros=-5_000, required_micros=500_000) - assert "$0.00 of credit left" in message - - -def test_summary_truncates_long_name(): - """A long document name is truncated in the title.""" - title, _ = msg.summary("a" * 50, balance_micros=0, required_micros=1_000) - assert title == f"Insufficient credits: {'a' * 40}..." diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py new file mode 100644 index 000000000..606e985f2 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py @@ -0,0 +1,32 @@ +"""Unit tests for page-limit presentation logic.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.messages import page_limit as msg + +pytestmark = pytest.mark.unit + + +def test_operation_id_encodes_search_space(): + """The operation id embeds the search space id.""" + assert msg.operation_id("doc.pdf", 9).startswith("page_limit_9_") + + +def test_summary_title_and_message(): + """The summary states the document and the used/limit page counts.""" + title, message = msg.summary( + "short.pdf", pages_used=95, pages_limit=100, pages_to_add=10 + ) + assert title == "Page limit exceeded: short.pdf" + assert message == ( + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." + ) + + +def test_summary_truncates_long_name(): + """A long document name is truncated in the title.""" + title, _ = msg.summary("a" * 50, pages_used=1, pages_limit=2, pages_to_add=1) + assert title == f"Page limit exceeded: {'a' * 40}..." diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_text.py b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py index 183779a9c..bf3611607 100644 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_text.py +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.notifications.service.messages.text import format_title, truncate +from app.notifications.service.messages.text import truncate pytestmark = pytest.mark.unit @@ -22,22 +22,3 @@ def test_truncate_keeps_text_at_exact_limit(): def test_truncate_appends_ellipsis_when_over_limit(): """Text past the limit is cut to the limit and gains an ellipsis.""" assert truncate("a" * 41, 40) == "a" * 40 + "..." - - -def test_format_title_keeps_short_name(): - """Short names are joined to the prefix without truncation.""" - assert format_title("Ready: ", "report.pdf") == "Ready: report.pdf" - - -def test_format_title_truncates_long_name(): - """Long names are truncated so the full title fits the DB limit.""" - long_name = "a" * 250 - title = format_title("Processing: ", long_name) - assert len(title) == 200 - assert title.startswith("Processing: ") - assert title.endswith("...") - - -def test_format_title_respects_custom_max_length(): - """A custom max length caps the title.""" - assert len(format_title("Go: ", "hello world", max_length=10)) == 10 diff --git a/surfsense_backend/tests/unit/observability/test_helpers.py b/surfsense_backend/tests/unit/observability/test_helpers.py index eafb8b626..ae60c1939 100644 --- a/surfsense_backend/tests/unit/observability/test_helpers.py +++ b/surfsense_backend/tests/unit/observability/test_helpers.py @@ -31,10 +31,10 @@ def _disable_otel(monkeypatch: pytest.MonkeyPatch): ("process_file_upload_with_document", "process"), ("process_circleback_meeting", "process"), ("generate_video_presentation", "generate"), - ("podcast.draft_transcript", "podcast.draft"), - ("podcast.render_audio", "podcast.render"), + ("generate_content_podcast", "generate"), ("cleanup_stale_indexing_notifications", "cleanup"), - ("reconcile_pending_stripe_credit_purchases", "reconcile"), + ("reconcile_pending_stripe_page_purchases", "reconcile"), + ("reconcile_pending_stripe_token_purchases", "reconcile"), ("check_periodic_schedules", "check"), ("ai_sort_search_space", "ai"), ("index_notion_pages", "index"), diff --git a/surfsense_backend/tests/unit/podcasts/conftest.py b/surfsense_backend/tests/unit/podcasts/conftest.py deleted file mode 100644 index c77eb1cc6..000000000 --- a/surfsense_backend/tests/unit/podcasts/conftest.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Shared builders for podcast unit tests. - -These tests exercise pure logic through public interfaces with no test doubles: -the brief and transcript factories build valid aggregates so each test states -only the fields it cares about. Stateful, persistence-backed paths (the lifecycle -service, the Celery task bodies) are covered by the integration suite against a -real database. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) - - -@pytest.fixture -def make_spec(): - """Factory for a valid :class:`PodcastSpec`; override only what matters.""" - - def _make( - *, - language: str = "en", - style: PodcastStyle = PodcastStyle.CONVERSATIONAL, - speakers: list[SpeakerSpec] | None = None, - min_seconds: int = 600, - max_seconds: int = 1200, - focus: str | None = None, - ) -> PodcastSpec: - if speakers is None: - speakers = [ - SpeakerSpec( - slot=0, - name="Host", - role=SpeakerRole.HOST, - voice_id="kokoro:am_adam", - ), - SpeakerSpec( - slot=1, - name="Guest", - role=SpeakerRole.GUEST, - voice_id="kokoro:af_bella", - ), - ] - return PodcastSpec( - language=language, - style=style, - speakers=speakers, - duration=DurationTarget(min_seconds=min_seconds, max_seconds=max_seconds), - focus=focus, - ) - - return _make - - -@pytest.fixture -def make_transcript(): - """Factory for a valid :class:`Transcript`.""" - - def _make(turns: list[tuple[int, str]] | None = None) -> Transcript: - if turns is None: - turns = [(0, "Welcome to the show."), (1, "Glad to be here.")] - return Transcript( - turns=[TranscriptTurn(speaker=slot, text=text) for slot, text in turns] - ) - - return _make diff --git a/surfsense_backend/tests/unit/podcasts/test_api_schemas.py b/surfsense_backend/tests/unit/podcasts/test_api_schemas.py deleted file mode 100644 index 41664ac64..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_api_schemas.py +++ /dev/null @@ -1,94 +0,0 @@ -"""The API read model the frontend renders from. - -``PodcastDetail.of`` maps a stored podcast row to the detail view and action -responses: it exposes the deserialized brief and transcript and a simple -``has_audio`` flag the client can't derive from the published Zero columns. Each -test builds a row in one lifecycle shape and asserts the mapping reflects it. -""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest - -from app.podcasts.api.schemas import PodcastDetail -from app.podcasts.persistence import Podcast, PodcastStatus - -pytestmark = pytest.mark.unit - - -def _podcast(*, status: PodcastStatus = PodcastStatus.PENDING, **columns) -> Podcast: - """A persisted-looking row: the id and created_at a saved podcast would carry.""" - podcast = Podcast( - title="Episode", - search_space_id=3, - status=status, - spec_version=1, - **columns, - ) - podcast.id = 1 - podcast.created_at = datetime.now(UTC) - return podcast - - -def test_a_fresh_podcast_exposes_no_brief_transcript_or_audio(): - detail = PodcastDetail.of(_podcast()) - - assert detail.status == PodcastStatus.PENDING - assert detail.spec is None - assert detail.transcript is None - assert detail.has_audio is False - - -def test_an_awaiting_brief_podcast_exposes_the_deserialized_brief(make_spec): - podcast = _podcast( - status=PodcastStatus.AWAITING_BRIEF, - spec=make_spec(language="fr").model_dump(mode="json"), - ) - - detail = PodcastDetail.of(podcast) - - assert detail.spec is not None - assert detail.spec.language == "fr" - - -def test_a_legacy_episode_still_exposes_its_transcript_and_audio(): - # Pre-rework rows stored [{speaker_id, dialog}] and a local file path; - # they must keep flowing through the new read model, not fail validation. - podcast = _podcast( - status=PodcastStatus.READY, - podcast_transcript=[ - {"speaker_id": 0, "dialog": "Welcome back."}, - {"speaker_id": 1, "dialog": "Glad to be here."}, - ], - file_location="/var/old/podcast.mp3", - ) - - detail = PodcastDetail.of(podcast) - - assert detail.has_audio is True - assert detail.transcript is not None - assert [(turn.speaker, turn.text) for turn in detail.transcript.turns] == [ - (0, "Welcome back."), - (1, "Glad to be here."), - ] - - -def test_a_ready_podcast_reports_available_audio(make_spec, make_transcript): - podcast = _podcast( - status=PodcastStatus.READY, - spec=make_spec().model_dump(mode="json"), - podcast_transcript=make_transcript().model_dump(mode="json"), - storage_backend="local", - storage_key="k", - duration_seconds=120, - ) - - detail = PodcastDetail.of(podcast) - - assert detail.status == PodcastStatus.READY - assert detail.has_audio is True - assert detail.duration_seconds == 120 - assert detail.transcript is not None - assert detail.error is None diff --git a/surfsense_backend/tests/unit/podcasts/test_renderer.py b/surfsense_backend/tests/unit/podcasts/test_renderer.py deleted file mode 100644 index bb7b8f181..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_renderer.py +++ /dev/null @@ -1,94 +0,0 @@ -"""The renderer refuses an inconsistent spec/transcript before spending work. - -Full synthesis-and-merge needs FFmpeg and a real provider, so it belongs to an -integration test. What is pure and worth securing here is the renderer's -contract that it validates the transcript against the brief up front: a turn -naming an unknown speaker, or a speaker naming an unknown voice, fails loudly -rather than producing silent or wrong audio. The TTS provider is an external -port, faked here and never expected to be called on these paths. -""" - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from app.podcasts.rendering import PodcastRenderer, RenderError -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) -from app.podcasts.tts import SynthesizedAudio -from app.podcasts.voices import CatalogVoice, TtsProvider, VoiceCatalog, VoiceGender - -pytestmark = pytest.mark.unit - - -class _UnusedTTS: - """A TTS port double that fails the test if it is ever asked to speak. - - These behaviors must short-circuit before synthesis, so any call here is a - regression. - """ - - @property - def container(self) -> str: - return "mp3" - - async def synthesize(self, _request): # pragma: no cover - must not run - raise AssertionError("synthesis should not be attempted") - return SynthesizedAudio(data=b"", container="mp3") - - -def _catalog_with(voice_id: str) -> VoiceCatalog: - return VoiceCatalog( - [ - CatalogVoice( - voice_id=voice_id, - provider=TtsProvider.KOKORO, - language="en-US", - display_name=voice_id, - gender=VoiceGender.MALE, - native_ref="am_adam", - ) - ] - ) - - -def _spec(voice_id: str) -> PodcastSpec: - return PodcastSpec( - language="en", - speakers=[ - SpeakerSpec(slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_id) - ], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - - -async def test_render_rejects_a_turn_for_an_unknown_speaker(tmp_path): - renderer = PodcastRenderer( - tts=_UnusedTTS(), catalog=_catalog_with("kokoro:am_adam") - ) - transcript = Transcript(turns=[TranscriptTurn(speaker=5, text="Who am I?")]) - - with pytest.raises(RenderError): - await renderer.render( - spec=_spec("kokoro:am_adam"), transcript=transcript, workdir=Path(tmp_path) - ) - - -async def test_render_rejects_a_speaker_whose_voice_is_not_in_the_catalog(tmp_path): - renderer = PodcastRenderer( - tts=_UnusedTTS(), catalog=_catalog_with("kokoro:am_adam") - ) - transcript = Transcript(turns=[TranscriptTurn(speaker=0, text="Hello.")]) - - with pytest.raises(RenderError): - await renderer.render( - spec=_spec("kokoro:ghost"), transcript=transcript, workdir=Path(tmp_path) - ) diff --git a/surfsense_backend/tests/unit/podcasts/test_resolution.py b/surfsense_backend/tests/unit/podcasts/test_resolution.py deleted file mode 100644 index aab44f8fb..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_resolution.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Default language and voice selection for a fresh brief. - -Resolution is what lets most briefs need no edits: it proposes a sensible -language and a distinct voice per speaker. These tests state the policy -("reuse what the user last chose, else English"; "two speakers should sound -like two people") through the public resolver functions and the real catalog. -We never guess the language from source content. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.resolution import ( - DEFAULT_LANGUAGE, - LanguageContext, - VoiceResolutionError, - resolve_language, - resolve_voices, -) -from app.podcasts.voices import TtsProvider, get_voice_catalog - -pytestmark = pytest.mark.unit - - -def test_last_used_language_is_reused(): - context = LanguageContext(last_used="fr") - assert resolve_language(context) == "fr" - - -def test_first_time_user_with_no_signal_gets_the_default(): - assert resolve_language(LanguageContext()) == DEFAULT_LANGUAGE - - -def test_two_speakers_get_distinct_voices(): - """A two-speaker episode should not voice both with the same person.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="en", speaker_count=2 - ) - assert len(voices) == 2 - assert voices[0].voice_id != voices[1].voice_id - - -def test_a_users_preferred_voice_is_reused_when_still_valid(): - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="en", - speaker_count=2, - preferred=["kokoro:af_bella"], - ) - assert voices[0].voice_id == "kokoro:af_bella" - - -def test_a_preferred_voice_invalid_for_the_language_is_replaced(): - """A stale preference (wrong provider/language) is silently dropped.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="en", - speaker_count=1, - preferred=["kokoro:does-not-exist"], - ) - assert voices[0].voice_id in { - v.voice_id for v in catalog.for_provider(TtsProvider.KOKORO) - } - - -def test_resolution_fails_when_no_voice_speaks_the_language(): - """If a provider can't speak the language at all, that is surfaced loudly.""" - catalog = get_voice_catalog() - with pytest.raises(VoiceResolutionError): - resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="xx", - speaker_count=1, - ) - - -def test_every_speaker_is_assigned_even_when_voices_run_out(): - """With one available voice, both speakers still get one rather than failing.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="fr", speaker_count=2 - ) - assert len(voices) == 2 - - -def test_speaker_count_must_be_positive(): - catalog = get_voice_catalog() - with pytest.raises(ValueError): - resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="en", speaker_count=0 - ) diff --git a/surfsense_backend/tests/unit/podcasts/test_spec.py b/surfsense_backend/tests/unit/podcasts/test_spec.py deleted file mode 100644 index 77e720286..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_spec.py +++ /dev/null @@ -1,170 +0,0 @@ -"""The brief and transcript contracts. - -A brief is what a user approves before any tokens or audio are spent, so its -validation rules are real behavior: they are the guardrails that keep a -nonsensical or ambiguous brief from ever reaching the expensive stages. These -tests pin those rules through construction of the public Pydantic models. -""" - -from __future__ import annotations - -import pytest -from pydantic import ValidationError - -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, - normalize_language_tag, -) - -pytestmark = pytest.mark.unit - - -def _speaker(slot: int, voice_id: str = "kokoro:am_adam") -> SpeakerSpec: - return SpeakerSpec( - slot=slot, name=f"Speaker {slot}", role=SpeakerRole.HOST, voice_id=voice_id - ) - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - ("EN", "en"), - ("en-US", "en-US"), - ("PT-BR", "pt-BR"), - (" fr ", "fr"), - ], -) -def test_language_is_normalized_to_canonical_form(raw, expected): - """The primary subtag is lowercased and surrounding space trimmed.""" - assert normalize_language_tag(raw) == expected - - -@pytest.mark.parametrize("invalid", ["", "e", "english!", "123", "en_US"]) -def test_invalid_language_tags_are_rejected(invalid): - """Tags that are not BCP-47-shaped never reach a brief.""" - with pytest.raises(ValueError): - normalize_language_tag(invalid) - - -def test_spec_normalizes_its_language_on_construction(): - """A brief stores a canonical language regardless of how it was entered.""" - spec = PodcastSpec( - language="EN-us", - speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - assert spec.language == "en-us" - - -def test_speakers_must_have_unique_slots(): - """Slots are the join key to transcript turns, so duplicates are invalid.""" - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - speakers=[_speaker(0), _speaker(0, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - - -def test_a_brief_needs_at_least_one_speaker(): - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - speakers=[], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - - -def test_a_monologue_brief_carries_exactly_one_speaker(): - spec = PodcastSpec( - language="en", - style=PodcastStyle.MONOLOGUE, - speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - assert spec.style is PodcastStyle.MONOLOGUE - - -def test_a_monologue_brief_rejects_multiple_speakers(): - """One voice is what 'monologue' means; a second speaker is a user error.""" - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - style=PodcastStyle.MONOLOGUE, - speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - - -def test_duration_rejects_an_inverted_range(): - """A max below the min is a user error caught at the brief gate.""" - with pytest.raises(ValidationError): - DurationTarget(min_seconds=1200, max_seconds=600) - - -def test_duration_midpoint_is_where_drafting_aims(): - assert DurationTarget(min_seconds=600, max_seconds=1200).midpoint_seconds == 900 - assert DurationTarget(min_seconds=600, max_seconds=1200).midpoint_minutes == 15 - - -def test_duration_loads_legacy_minute_fields_from_json(): - duration = DurationTarget.model_validate({"min_minutes": 10, "max_minutes": 20}) - assert duration.min_seconds == 600 - assert duration.max_seconds == 1200 - - -def test_blank_focus_becomes_absent(): - """Whitespace-only steer is treated as no steer.""" - spec = PodcastSpec( - language="en", - speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), - focus=" ", - ) - assert spec.focus is None - - -def test_speaker_for_returns_the_speaker_bound_to_a_slot(): - spec = PodcastSpec( - language="en", - speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - assert spec.speaker_for(1).voice_id == "kokoro:af_bella" - - -def test_speaker_for_raises_when_no_speaker_matches(): - spec = PodcastSpec( - language="en", - speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), - ) - with pytest.raises(KeyError): - spec.speaker_for(99) - - -def test_transcript_word_count_sums_spoken_words(): - """Word count is what drafting checks runtime against, so it must be exact.""" - transcript = Transcript( - turns=[ - TranscriptTurn(speaker=0, text="hello there world"), - TranscriptTurn(speaker=1, text="one two"), - ] - ) - assert transcript.word_count == 5 - - -def test_blank_transcript_turns_are_rejected(): - with pytest.raises(ValidationError): - TranscriptTurn(speaker=0, text=" ") - - -def test_a_transcript_needs_at_least_one_turn(): - with pytest.raises(ValidationError): - Transcript(turns=[]) diff --git a/surfsense_backend/tests/unit/podcasts/test_structured.py b/surfsense_backend/tests/unit/podcasts/test_structured.py deleted file mode 100644 index 8d7b2226a..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_structured.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Parsing a model's reply into a structured shape. - -Agent LLMs wrap JSON in prose and markdown fences. ``invoke_json`` exists so -every generation node tolerates that the same way. The LLM is an external -boundary, so it is faked with a canned reply; the behavior under test is the -parsing, not the model. -""" - -from __future__ import annotations - -import pytest -from pydantic import BaseModel - -from app.podcasts.generation.structured import StructuredOutputError, invoke_json - -pytestmark = pytest.mark.unit - - -class _Shape(BaseModel): - name: str - count: int - - -class _CannedLLM: - """A TTS-free stand-in for the chat model: replies with one fixed string.""" - - def __init__(self, reply: str) -> None: - self._reply = reply - - async def ainvoke(self, _messages): - return SimpleReply(self._reply) - - -class SimpleReply: - def __init__(self, content: str) -> None: - self.content = content - - -async def _parse(reply: str) -> _Shape: - return await invoke_json(_CannedLLM(reply), [], _Shape) - - -async def test_parses_a_clean_json_reply(): - shape = await _parse('{"name": "alpha", "count": 3}') - assert shape == _Shape(name="alpha", count=3) - - -async def test_parses_json_wrapped_in_a_markdown_fence(): - reply = '```json\n{"name": "beta", "count": 7}\n```' - shape = await _parse(reply) - assert shape == _Shape(name="beta", count=7) - - -async def test_extracts_json_embedded_in_prose(): - """Reasoning models prepend/append chatter around the object.""" - reply = 'Sure, here you go: {"name": "gamma", "count": 1} — hope that helps!' - shape = await _parse(reply) - assert shape == _Shape(name="gamma", count=1) - - -async def test_raises_when_there_is_no_json_object(): - with pytest.raises(StructuredOutputError): - await _parse("I could not produce that.") - - -async def test_raises_when_the_json_does_not_match_the_shape(): - with pytest.raises(StructuredOutputError): - await _parse('{"name": "delta"}') diff --git a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py deleted file mode 100644 index d120d4bfc..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py +++ /dev/null @@ -1,158 +0,0 @@ -"""The voice catalog and provider identification. - -The catalog is the single source of truth for which voices exist; resolution, -the API picker, and the renderer all depend on its lookups behaving correctly. -These tests build a small catalog of their own so they assert on the lookup -behavior, not on which specific voices ship. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.voices import ( - ANY_LANGUAGE, - CatalogVoice, - TtsProvider, - VoiceCatalog, - VoiceGender, - provider_from_service, -) - -pytestmark = pytest.mark.unit - - -def _voice( - voice_id: str, - *, - provider: TtsProvider = TtsProvider.KOKORO, - language: str = "en-US", - gender: VoiceGender = VoiceGender.MALE, -) -> CatalogVoice: - return CatalogVoice( - voice_id=voice_id, - provider=provider, - language=language, - display_name=voice_id, - gender=gender, - native_ref=voice_id, - ) - - -def test_for_provider_returns_only_that_providers_voices(): - catalog = VoiceCatalog( - [ - _voice("k1", provider=TtsProvider.KOKORO), - _voice("o1", provider=TtsProvider.OPENAI), - ] - ) - assert [v.voice_id for v in catalog.for_provider(TtsProvider.KOKORO)] == ["k1"] - - -def test_for_language_matches_on_the_primary_subtag(): - """A request for 'en' should match an 'en-US' voice (region-insensitive).""" - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert [v.voice_id for v in catalog.for_language(TtsProvider.KOKORO, "en")] == [ - "k1" - ] - - -def test_for_language_excludes_other_languages(): - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert catalog.for_language(TtsProvider.KOKORO, "fr") == [] - - -def test_an_any_language_voice_speaks_every_language(): - """Provider-agnostic voices (e.g. OpenAI) match whatever the text is in.""" - voice = _voice("o1", provider=TtsProvider.OPENAI, language=ANY_LANGUAGE) - assert voice.speaks("ja") - assert voice.speaks("pt-BR") - - -def test_supports_language_reports_availability(): - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert catalog.supports_language(TtsProvider.KOKORO, "en") - assert not catalog.supports_language(TtsProvider.KOKORO, "de") - - -def test_offerable_languages_for_a_concrete_roster_are_its_tags_only(): - """A provider whose voices are language-bound offers exactly those tags.""" - catalog = VoiceCatalog( - [ - _voice("k1", language="en-US"), - _voice("k2", language="fr"), - _voice("k3", language="fr"), - ] - ) - - offering = catalog.offerable_languages(TtsProvider.KOKORO) - - assert offering.languages == ["en-US", "fr"] - assert offering.allows_custom is False - - -def test_a_wildcard_roster_offers_the_curated_languages_and_custom_entry(): - """Voices that speak anything can't enumerate languages themselves, so the - catalog offers the curated common list and invites free entry.""" - catalog = VoiceCatalog( - [_voice("o1", provider=TtsProvider.OPENAI, language=ANY_LANGUAGE)] - ) - - offering = catalog.offerable_languages(TtsProvider.OPENAI) - - assert {"en", "fr", "sw", "hi", "zh"} <= set(offering.languages) - assert offering.allows_custom is True - - -def test_a_mixed_roster_offers_the_union_of_concrete_and_curated(): - catalog = VoiceCatalog( - [ - _voice("v1", provider=TtsProvider.VERTEX_AI, language="en-GB"), - _voice("v2", provider=TtsProvider.VERTEX_AI, language=ANY_LANGUAGE), - ] - ) - - offering = catalog.offerable_languages(TtsProvider.VERTEX_AI) - - assert "en-GB" in offering.languages - assert "fr" in offering.languages - assert offering.allows_custom is True - - -def test_a_provider_with_no_voices_offers_nothing(): - catalog = VoiceCatalog([_voice("k1")]) - - offering = catalog.offerable_languages(TtsProvider.OPENAI) - - assert offering.languages == [] - assert offering.allows_custom is False - - -def test_get_raises_for_an_unknown_voice(): - catalog = VoiceCatalog([_voice("k1")]) - with pytest.raises(KeyError): - catalog.get("nope") - - -def test_a_catalog_rejects_duplicate_voice_ids(): - """Stored ids must be unique so a brief's voice_id resolves unambiguously.""" - with pytest.raises(ValueError): - VoiceCatalog([_voice("dup"), _voice("dup")]) - - -@pytest.mark.parametrize( - ("service", "expected"), - [ - ("openai/tts-1", TtsProvider.OPENAI), - ("azure/neural", TtsProvider.AZURE), - ("vertex_ai/some-model", TtsProvider.VERTEX_AI), - ("local/kokoro", TtsProvider.KOKORO), - ], -) -def test_provider_is_identified_from_the_config_string(service, expected): - assert provider_from_service(service) == expected - - -def test_unknown_provider_prefix_is_rejected(): - with pytest.raises(ValueError): - provider_from_service("madeup/model") diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for ``supports_image_input`` derivation on BYOK chat config +endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + "provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. + assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 3d94c6c51..636b7de31 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -27,18 +27,9 @@ async def test_resolve_billing_for_auto_mode(monkeypatch): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - async def _no_auto_candidates(*_args, **_kwargs): - return [] - - monkeypatch.setattr( - image_generation_routes, - "auto_model_candidates", - _no_auto_candidates, - ) - - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( - session=None, + session=None, # Not consumed on this code path. config_id=0, # IMAGE_GEN_AUTO_MODE_ID search_space=search_space, ) @@ -54,48 +45,26 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", [ { "id": -1, - "connection_id": -101, - "model_id": "gpt-image-1", + "provider": "OPENAI", + "model_name": "gpt-image-1", "billing_tier": "premium", - "catalog": {"quota_reserve_micros": 75_000}, + "quota_reserve_micros": 75_000, }, { "id": -2, - "connection_id": -102, - "model_id": "google/gemini-2.5-flash-image", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", "billing_tier": "free", - "catalog": {}, - }, - ], - raising=False, - ) - monkeypatch.setattr( - config, - "GLOBAL_CONNECTIONS", - [ - { - "id": -101, - "provider": "openai", - "api_key": "sk-test", - "base_url": None, - "extra": {}, - }, - { - "id": -102, - "provider": "openrouter", - "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, }, ], raising=False, ) - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) # Premium with override. tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( @@ -125,7 +94,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( session=None, config_id=42, search_space=search_space ) @@ -136,7 +105,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): @pytest.mark.asyncio async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): - """When the request omits ``image_gen_model_id``, the helper + """When the request omits ``image_generation_config_id``, the helper must consult the search space's default — so a search space pinned to a premium global config still gates new requests by quota. """ @@ -145,34 +114,19 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", [ { "id": -7, - "connection_id": -101, - "model_id": "gpt-image-1", + "provider": "OPENAI", + "model_name": "gpt-image-1", "billing_tier": "premium", - "catalog": {}, - } - ], - raising=False, - ) - monkeypatch.setattr( - config, - "GLOBAL_CONNECTIONS", - [ - { - "id": -101, - "provider": "openai", - "api_key": "sk-test", - "base_url": None, - "extra": {}, } ], raising=False, ) - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7) + search_space = SimpleNamespace(image_generation_config_id=-7) ( tier, model, diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py index b43540ba7..fa8819b39 100644 --- a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -1,4 +1,27 @@ -"""Unit tests for ``_resolve_agent_billing_for_search_space``.""" +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", <base_model>)``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", <base_model>)``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. +* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" from __future__ import annotations @@ -11,6 +34,11 @@ import pytest pytestmark = pytest.mark.unit +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + class _FakeExecResult: def __init__(self, obj): self._obj = obj @@ -23,6 +51,14 @@ class _FakeExecResult: class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. + """ + def __init__(self, responses: list): self._responses = list(responses) @@ -31,6 +67,9 @@ class _FakeSession: return _FakeExecResult(None) return _FakeExecResult(self._responses.pop(0)) + async def commit(self) -> None: + pass + @dataclass class _FakePinResolution: @@ -39,33 +78,53 @@ class _FakePinResolution: from_existing_pin: bool = False -def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace: - return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id) +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) -def _make_byok_model( - *, id_: int, base_model: str | None = None, model_id: str = "gpt-byok" +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" ) -> SimpleNamespace: return SimpleNamespace( id=id_, - model_id=model_id, - catalog={"base_model": base_model} if base_model else {}, - connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None), + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, ) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", <base_model>)``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - async def _fake_resolve_pin(*_args, **kwargs): - assert kwargs["selected_llm_config_id"] == 0 - assert kwargs["thread_id"] == 99 + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + # Mock global config lookup to return a premium entry. def _fake_get_global(cfg_id): if cfg_id == -1: return { @@ -76,6 +135,8 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): } return None + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. import app.services.auto_model_pin_service as pin_module import app.services.llm_service as llm_module @@ -94,17 +155,76 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): @pytest.mark.asyncio -async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", <base_model>)``. Same path the pin service takes for + out-of-credit users (graceful degradation).""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(chat_model_id=0, user_id=user_id) - byok_model = _make_byok_model( - id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude" - ) - session = _FakeSession([search_space, byok_model]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - async def _fake_resolve_pin(*_args, **_kwargs): + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + +@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") import app.services.auto_model_pin_service as pin_module @@ -124,10 +244,13 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): @pytest.mark.asyncio async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42, thread_id=None @@ -140,10 +263,13 @@ async def test_auto_mode_without_thread_id_falls_back_to_free(): @pytest.mark.asyncio async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) async def _fake_resolve_pin(*args, **kwargs): raise ValueError("thread missing") @@ -165,10 +291,12 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): @pytest.mark.asyncio async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) def _fake_get_global(cfg_id): return { @@ -192,14 +320,49 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch): @pytest.mark.asyncio -async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): +async def test_negative_id_free_global_returns_free(monkeypatch): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) def _fake_get_global(cfg_id): - return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"} + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. + } import app.services.llm_service as llm_module @@ -215,12 +378,14 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat @pytest.mark.asyncio async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(chat_model_id=23, user_id=user_id) - byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet") - session = _FakeSession([search_space, byok_model]) + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -233,10 +398,13 @@ async def test_positive_id_byok_is_always_free(): @pytest.mark.asyncio async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -251,18 +419,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): async def test_search_space_not_found_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space + session = _FakeSession([None]) + with pytest.raises(ValueError, match="Search space"): - await _resolve_agent_billing_for_search_space( - _FakeSession([None]), search_space_id=999 - ) + await _resolve_agent_billing_for_search_space(session, search_space_id=999) @pytest.mark.asyncio -async def test_chat_model_id_none_raises_value_error(): +async def test_agent_llm_id_none_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) - with pytest.raises(ValueError, match="chat_model_id"): + with pytest.raises(ValueError, match="agent_llm_id"): await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 598e9b1ab..d1af29aeb 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -17,39 +17,8 @@ from app.services.auto_model_pin_service import ( pytestmark = pytest.mark.unit -class _FakeRedis: - def __init__(self): - self.values: dict[str, str] = {} - self.ttls: dict[str, int] = {} - - def set(self, key: str, value: str, *, ex: int | None = None): - self.values[key] = value - if ex is not None: - self.ttls[key] = ex - return True - - def mget(self, keys: list[str]): - return [self.values.get(key) for key in keys] - - def delete(self, *keys: str): - removed = 0 - for key in keys: - if key in self.values: - removed += 1 - self.values.pop(key, None) - self.ttls.pop(key, None) - return removed - - def scan_iter(self, pattern: str): - prefix = pattern.removesuffix("*") - return (key for key in list(self.values) if key.startswith(prefix)) - - @pytest.fixture(autouse=True) -def _clear_runtime_cooldown_map(monkeypatch): - import app.services.auto_model_pin_service as svc - - monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis()) +def _clear_runtime_cooldown_map(): clear_runtime_cooldown() clear_healthy() yield @@ -63,9 +32,8 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, *, thread=None, scalars=None): + def __init__(self, thread): self._thread = thread - self._scalars = scalars or [] def unique(self): return self @@ -73,71 +41,19 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread - def scalars(self): - return SimpleNamespace(all=lambda: self._scalars) - class _FakeSession: - def __init__(self, thread, *, models=None): + def __init__(self, thread): self.thread = thread - self.models = models or [] self.commit_count = 0 - self.execute_count = 0 async def execute(self, _stmt): - self.execute_count += 1 - if self.execute_count == 1: - return _FakeExecResult(thread=self.thread) - return _FakeExecResult(scalars=self.models) + return _FakeExecResult(self.thread) async def commit(self): self.commit_count += 1 -def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): - """Patch the new global model catalog shape from compact legacy cfg fixtures.""" - connections = [] - models = [] - for cfg in configs: - config_id = int(cfg["id"]) - connection_id = config_id - 100_000 - provider = cfg.get("provider") or cfg.get("litellm_provider") - model_name = cfg["model_name"] - connections.append( - { - "id": connection_id, - "provider": provider, - "scope": "GLOBAL", - "enabled": True, - } - ) - models.append( - { - "id": config_id, - "connection_id": connection_id, - "model_id": model_name, - "display_name": cfg.get("name") or model_name, - "supports_chat": cfg.get("supports_chat", True), - "supports_image_input": cfg.get("supports_image_input", True), - "supports_tools": cfg.get("supports_tools", True), - "supports_image_generation": cfg.get( - "supports_image_generation", False - ), - "capabilities_override": cfg.get("capabilities_override") or {}, - "billing_tier": cfg.get("billing_tier", "free"), - "catalog": { - "auto_pin_tier": cfg.get("auto_pin_tier"), - "quality_score": cfg.get("quality_score") - or cfg.get("quality_score_static"), - }, - } - ) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) - monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) - monkeypatch.setattr(config, "GLOBAL_MODELS", models) - - def _thread( *, search_space_id: int = 10, @@ -155,19 +71,14 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -179,7 +90,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -200,13 +111,13 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", @@ -214,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -227,7 +138,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -243,19 +154,17 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): @pytest.mark.asyncio -async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model( - monkeypatch, -): +async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.1", "api_key": "k1", "billing_tier": "premium", @@ -264,7 +173,7 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.4", "api_key": "k2", "billing_tier": "premium", @@ -273,39 +182,12 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode }, { "id": -3, - "litellm_provider": "anthropic", - "model_name": "claude-opus", + "provider": "OPENROUTER", + "model_name": "openai/gpt-5.4", "api_key": "k3", "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 99, - }, - { - "id": -4, - "litellm_provider": "openai", - "model_name": "gpt-5.3", - "api_key": "k4", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 98, - }, - { - "id": -5, - "litellm_provider": "gemini", - "model_name": "gemini-3-pro", - "api_key": "k5", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 97, - }, - { - "id": -6, - "litellm_provider": "xai", - "model_name": "grok-5", - "api_key": "k6", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 96, + "auto_pin_tier": "B", + "quality_score": 100, }, ], ) @@ -314,7 +196,7 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -325,7 +207,7 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode user_id="00000000-0000-0000-0000-000000000001", selected_llm_config_id=0, ) - assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6} + assert result.resolved_llm_config_id == -2 assert result.resolved_tier == "premium" @@ -334,13 +216,13 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -350,11 +232,11 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): async def _must_not_call(*_args, **_kwargs): raise AssertionError( - "credit_get_usage should not be called for valid pin reuse" + "premium_get_usage should not be called for valid pin reuse" ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -375,13 +257,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -393,7 +275,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -413,20 +295,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -438,7 +320,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -458,20 +340,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -483,7 +365,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -503,20 +385,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -528,7 +410,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -551,16 +433,11 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-2)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -581,16 +458,11 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-999)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -598,7 +470,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -615,7 +487,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): # --------------------------------------------------------------------------- -# Quality-aware pin selection (Auto upgrade) +# Quality-aware pin selection (Auto Fastest upgrade) # --------------------------------------------------------------------------- @@ -626,13 +498,13 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "venice/dead-model", "api_key": "k1", "billing_tier": "free", @@ -642,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-flash", "api_key": "k1", "billing_tier": "free", @@ -657,7 +529,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -678,13 +550,13 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -694,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "openai/gpt-5", "api_key": "k-or", "billing_tier": "premium", @@ -709,7 +581,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -730,13 +602,13 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -746,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-flash:free", "api_key": "k-or", "billing_tier": "free", @@ -761,7 +633,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -784,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): high_score_cfgs = [ { "id": -i, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": f"gpt-x-{i}", "api_key": "k", "billing_tier": "premium", @@ -796,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): ] low_score_trap = { "id": -99, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "tiny-legacy", "api_key": "k", "billing_tier": "premium", @@ -804,9 +676,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): "quality_score": 10, "health_gated": False, } - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [*high_score_cfgs, low_score_trap], ) @@ -814,7 +686,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -851,13 +723,13 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "venice/dead-model", "api_key": "k", "billing_tier": "premium", @@ -867,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -882,7 +754,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -903,13 +775,13 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -919,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5-pro", "api_key": "k", "billing_tier": "premium", @@ -931,10 +803,10 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("credit_get_usage should not run on pin reuse") + raise AssertionError("premium_get_usage should not run on pin reuse") monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -961,13 +833,13 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -977,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -992,7 +864,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -1009,86 +881,18 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): assert result.from_existing_pin is False -def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch): - import app.services.auto_model_pin_service as svc - - mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123) - - redis_client = svc._runtime_cooldown_redis - assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited" - assert redis_client.ttls["auto:cooldown:llm:-9"] == 123 - - -@pytest.mark.asyncio -async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch): - """A Redis cooldown written by another worker should invalidate local pins.""" - import app.services.auto_model_pin_service as svc - from app.config import config - - session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, - config, - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "google/gemma-4-26b-a4b-it:free", - "api_key": "k", - "billing_tier": "free", - "auto_pin_tier": "C", - "quality_score": 90, - "health_gated": False, - }, - { - "id": -2, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash:free", - "api_key": "k", - "billing_tier": "free", - "auto_pin_tier": "C", - "quality_score": 80, - "health_gated": False, - }, - ], - ) - svc._runtime_cooldown_redis.set( - "auto:cooldown:llm:-1", - "provider_rate_limited", - ex=600, - ) - - async def _blocked(*_args, **_kwargs): - return _FakeQuotaResult(allowed=False) - - monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", - _blocked, - ) - - result = await resolve_or_get_pinned_llm_config_id( - session, - thread_id=1, - search_space_id=10, - user_id="00000000-0000-0000-0000-000000000001", - selected_llm_config_id=0, - ) - assert result.resolved_llm_config_id == -2 - assert result.from_existing_pin is False - - @pytest.mark.asyncio async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -1100,10 +904,10 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("credit_get_usage should not run on healthy pin reuse") + raise AssertionError("premium_get_usage should not run on healthy pin reuse") monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -1127,13 +931,13 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -1143,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -1158,7 +962,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index c1e66feb9..0e19b80e4 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -45,9 +45,8 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, *, thread=None, scalars=None): + def __init__(self, thread): self._thread = thread - self._scalars = scalars or [] def unique(self): return self @@ -55,21 +54,14 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread - def scalars(self): - return SimpleNamespace(all=lambda: self._scalars) - class _FakeSession: def __init__(self, thread): self.thread = thread self.commit_count = 0 - self.execute_count = 0 async def execute(self, _stmt): - self.execute_count += 1 - if self.execute_count == 1: - return _FakeExecResult(thread=self.thread) - return _FakeExecResult(scalars=[]) + return _FakeExecResult(self.thread) async def commit(self): self.commit_count += 1 @@ -79,64 +71,10 @@ def _thread(*, pinned: int | None = None): return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) -def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): - from app.services.provider_capabilities import derive_supports_image_input - - connections = [] - models = [] - for cfg in configs: - config_id = int(cfg["id"]) - connection_id = config_id - 100_000 - provider = cfg.get("provider") or cfg.get("litellm_provider") - model_name = cfg["model_name"] - if "supports_image_input" not in cfg: - litellm_params = cfg.get("litellm_params") or {} - base_model = ( - litellm_params.get("base_model") - if isinstance(litellm_params, dict) - else None - ) - cfg["supports_image_input"] = derive_supports_image_input( - provider=provider, - model_name=model_name, - base_model=base_model, - custom_provider=cfg.get("custom_provider"), - ) - connections.append( - { - "id": connection_id, - "provider": provider, - "scope": "GLOBAL", - "enabled": True, - } - ) - model = { - "id": config_id, - "connection_id": connection_id, - "model_id": model_name, - "display_name": cfg.get("name") or model_name, - "supports_chat": cfg.get("supports_chat", True), - "supports_tools": cfg.get("supports_tools", True), - "supports_image_generation": cfg.get("supports_image_generation", False), - "capabilities_override": cfg.get("capabilities_override") or {}, - "billing_tier": cfg.get("billing_tier", "free"), - "catalog": { - "auto_pin_tier": cfg.get("auto_pin_tier"), - "quality_score": cfg.get("quality_score"), - }, - "supports_image_input": cfg["supports_image_input"], - } - models.append(model) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) - monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) - monkeypatch.setattr(config, "GLOBAL_MODELS", models) - - def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: return { "id": id_, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": f"vision-{id_}", "api_key": "k", "billing_tier": tier, @@ -149,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: return { "id": id_, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": f"text-{id_}", "api_key": "k", "billing_tier": tier, @@ -170,9 +108,13 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)]) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -198,9 +140,13 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned=-1)) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)]) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -226,13 +172,13 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned=-2)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -257,11 +203,13 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)] + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -283,9 +231,13 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)]) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -309,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): session = _FakeSession(_thread()) cfg_unannotated_vision = { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-4o", # known vision model in LiteLLM map "api_key": "k", "billing_tier": "free", @@ -317,9 +269,9 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): "quality_score": 80, # NOTE: no supports_image_input key } - _set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision]) + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py index 8e2c2f1da..c820724ed 100644 --- a/surfsense_backend/tests/unit/services/test_billable_call.py +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -38,13 +38,11 @@ class _FakeQuotaResult: used: int = 0, limit: int = 5_000_000, remaining: int = 5_000_000, - balance: int = 5_000_000, ) -> None: self.allowed = allowed self.used = used self.limit = limit self.remaining = remaining - self.balance = balance class _FakeSession: @@ -120,17 +118,17 @@ def _patch_isolation_layer( return object() monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_reserve", + "app.services.billable_calls.TokenQuotaService.premium_reserve", _fake_reserve, raising=False, ) monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_finalize", + "app.services.billable_calls.TokenQuotaService.premium_finalize", _fake_finalize, raising=False, ) monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_release", + "app.services.billable_calls.TokenQuotaService.premium_release", _fake_release, raising=False, ) @@ -203,7 +201,9 @@ async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): spies = _patch_isolation_layer( monkeypatch, - reserve_result=_FakeQuotaResult(allowed=False, balance=0, remaining=0), + reserve_result=_FakeQuotaResult( + allowed=False, used=5_000_000, limit=5_000_000, remaining=0 + ), ) user_id = uuid4() @@ -220,7 +220,8 @@ async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): err = exc_info.value assert err.usage_type == "image_generation" - assert err.balance_micros == 0 + assert err.used_micros == 5_000_000 + assert err.limit_micros == 5_000_000 assert err.remaining_micros == 0 # Reserve was attempted, but no finalize/release on a denied reserve # — the reservation never actually held credit. @@ -531,7 +532,7 @@ async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): spies = _patch_isolation_layer( monkeypatch, reserve_result=_FakeQuotaResult( - allowed=False, balance=500_000, remaining=500_000 + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 ), ) user_id = uuid4() @@ -551,7 +552,6 @@ async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): err = exc_info.value assert err.usage_type == "video_presentation_generation" - assert err.balance_micros == 500_000 assert err.remaining_micros == 500_000 assert spies["reserve"][0]["reserve_micros"] == 1_000_000 assert spies["finalize"] == [] diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index adcfeed48..571e7d15b 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -1,4 +1,19 @@ -"""Image-gen call sites must pass each config's explicit ``api_base``.""" +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. +""" from __future__ import annotations @@ -11,23 +26,22 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio -async def test_global_openrouter_image_gen_sets_explicit_api_base(): - """The global-config branch forwards the explicit OpenRouter base.""" +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. + """ from app.routes import image_generation_routes - global_model = { + cfg = { "id": -20_001, - "connection_id": -101, - "model_id": "openai/gpt-image-1", - "supports_image_generation": True, - "capabilities_override": {}, - } - global_connection = { - "id": -101, - "provider": "openrouter", + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, } captured: dict = {} @@ -37,7 +51,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) image_gen = MagicMock() - image_gen.image_gen_model_id = global_model["id"] + image_gen.image_generation_config_id = cfg["id"] image_gen.prompt = "test" image_gen.n = 1 image_gen.quality = None @@ -47,19 +61,14 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): image_gen.model = None search_space = MagicMock() - search_space.image_gen_model_id = global_model["id"] + search_space.image_generation_config_id = cfg["id"] session = MagicMock() with ( patch.object( image_generation_routes, - "_get_global_model", - return_value=global_model, - ), - patch.object( - image_generation_routes, - "_get_global_connection", - return_value=global_connection, + "_get_global_image_gen_config", + return_value=cfg, ), patch.object( image_generation_routes, @@ -71,31 +80,30 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): session=session, image_gen=image_gen, search_space=search_space ) + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1" @pytest.mark.asyncio -async def test_generate_image_tool_global_sets_explicit_api_base(): - """Same explicit-base behavior at the agent tool entry point — both surfaces share +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share the same OpenRouter config payloads.""" from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( generate_image as gi_module, ) - global_model = { + cfg = { "id": -20_001, - "connection_id": -101, - "model_id": "openai/gpt-image-1", - "supports_image_generation": True, - "capabilities_override": {}, - } - global_connection = { - "id": -101, - "provider": "openrouter", + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, + "api_base": "", + "api_version": None, + "litellm_params": {}, } captured: dict = {} @@ -111,7 +119,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): search_space = MagicMock() search_space.id = 1 - search_space.image_gen_model_id = global_model["id"] + search_space.image_generation_config_id = cfg["id"] session_cm = AsyncMock() session = AsyncMock() @@ -134,10 +142,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): with ( patch.object(gi_module, "shielded_async_session", return_value=session_cm), - patch.object(gi_module, "_get_global_model", return_value=global_model), - patch.object( - gi_module, "_get_global_connection", return_value=global_connection - ), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), patch.object( gi_module, "aimage_generation", side_effect=fake_aimage_generation ), @@ -166,16 +171,20 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): assert captured["model"] == "openrouter/openai/gpt-image-1" -def test_image_gen_router_deployment_sets_explicit_api_base(): - """The Auto-mode router pool carries explicit api_base into deployments.""" +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ from app.services.image_gen_router_service import ImageGenRouterService deployment = ImageGenRouterService._config_to_deployment( { "model_name": "openai/gpt-image-1", - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", } ) assert deployment is not None diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index 349a7d445..c309ff881 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -25,10 +25,10 @@ def _fake_yaml_config( return { "id": id, "name": f"yaml-{id}", - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": model_name, "api_key": "sk-test", - "api_base": "https://api.openai.com/v1", + "api_base": "", "billing_tier": billing_tier, "rpm": 100, "tpm": 100_000, @@ -54,10 +54,10 @@ def _fake_openrouter_config( return { "id": id, "name": f"or-{id}", - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_name, "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", "billing_tier": billing_tier, "rpm": 20 if billing_tier == "free" else 200, "tpm": 100_000 if billing_tier == "free" else 1_000_000, @@ -217,64 +217,10 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter(): model_name="meta-llama/llama-3.3-70b:free", billing_tier="free", ) - global_connections = [ - { - "id": -110_001, - "provider": "openrouter", - "scope": "GLOBAL", - "enabled": True, - }, - { - "id": -110_002, - "provider": "openrouter", - "scope": "GLOBAL", - "enabled": True, - }, - ] - global_models = [ - { - "id": or_premium["id"], - "connection_id": -110_001, - "model_id": or_premium["model_name"], - "display_name": or_premium["name"], - "supports_chat": True, - "supports_image_input": True, - "supports_tools": True, - "supports_image_generation": False, - "capabilities_override": {}, - "billing_tier": or_premium["billing_tier"], - "catalog": { - "auto_pin_tier": "A", - "quality_score": 50, - }, - }, - { - "id": or_free["id"], - "connection_id": -110_002, - "model_id": or_free["model_name"], - "display_name": or_free["name"], - "supports_chat": True, - "supports_image_input": True, - "supports_tools": True, - "supports_image_generation": False, - "capabilities_override": {}, - "billing_tier": or_free["billing_tier"], - "catalog": { - "auto_pin_tier": "A", - "quality_score": 50, - }, - }, - ] - original_configs = config.GLOBAL_LLM_CONFIGS - original_connections = config.GLOBAL_CONNECTIONS - original_models = config.GLOBAL_MODELS + original = config.GLOBAL_LLM_CONFIGS try: config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] - config.GLOBAL_CONNECTIONS = global_connections - config.GLOBAL_MODELS = global_models candidate_ids = {c["id"] for c in _global_candidates()} assert candidate_ids == {-10_001, -10_002} finally: - config.GLOBAL_LLM_CONFIGS = original_configs - config.GLOBAL_CONNECTIONS = original_connections - config.GLOBAL_MODELS = original_models + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py deleted file mode 100644 index 937eda806..000000000 --- a/surfsense_backend/tests/unit/services/test_model_connections.py +++ /dev/null @@ -1,78 +0,0 @@ -from app.services.global_model_catalog import materialize_global_model_catalog -from app.services.model_resolver import ensure_v1, to_litellm - - -def test_openai_compatible_resolver_uses_explicit_api_base() -> None: - model, kwargs = to_litellm( - { - "protocol": "OPENAI_COMPATIBLE", - "provider": "openai", - "base_url": "http://host.docker.internal:1234/v1", - "api_key": "local-key", - "extra": {}, - }, - "qwen/qwen3", - ) - - assert model == "openai/qwen/qwen3" - assert kwargs["api_base"] == "http://host.docker.internal:1234/v1" - assert kwargs["api_key"] == "local-key" - assert ensure_v1("http://example.com/v1") == "http://example.com/v1" - - -def test_ollama_resolver_uses_native_api_base() -> None: - model, kwargs = to_litellm( - { - "protocol": "OLLAMA", - "provider": "ollama_chat", - "base_url": "http://host.docker.internal:11434", - "api_key": None, - "extra": {}, - }, - "llama3.2", - ) - - assert model == "ollama_chat/llama3.2" - assert kwargs["api_base"] == "http://host.docker.internal:11434" - - -def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None: - connections, models = materialize_global_model_catalog( - chat_configs=[ - { - "id": -101, - "name": "OpenRouter Free", - "litellm_provider": "openrouter", - "model_name": "meta-llama/llama-3.1-8b-instruct:free", - "api_key": "sk-global-secret", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "free", - "anonymous_enabled": True, - "seo_enabled": True, - "rpm": 10, - "tpm": 1000, - }, - { - "id": -102, - "name": "OpenRouter Premium", - "litellm_provider": "openrouter", - "model_name": "anthropic/claude-sonnet-4", - "api_key": "sk-global-secret", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, - ], - image_configs=[], - ) - - assert len(connections) == 1 - assert connections[0]["api_key"] == "sk-global-secret" - assert {model["billing_tier"] for model in models} == {"free", "premium"} - assert models[0]["catalog"]["anonymous_enabled"] is True - assert models[0]["catalog"]["rpm"] == 10 - - public_connections = [ - {key: value for key, value in connection.items() if key != "api_key"} - for connection in connections - ] - assert "sk-" not in repr(public_connections) diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 147d62592..88fcf2db3 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): # --------------------------------------------------------------------------- -# _generate_image_gen_configs +# _generate_image_gen_configs / _generate_vision_llm_configs # --------------------------------------------------------------------------- @@ -263,15 +263,18 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["provider"] == "openrouter" + assert c["provider"] == "OPENROUTER" assert c[_OPENROUTER_DYNAMIC_MARKER] is True - # Emit the OpenRouter base URL at source so every call path passes an - # explicit api_base and cannot inherit a process-global endpoint. + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. assert c["api_base"] == "https://openrouter.ai/api/v1" def test_generate_image_gen_configs_assigns_image_id_offset(): - """Image configs use their own id_offset (-20000).""" + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). + """ from app.services.openrouter_integration_service import ( _generate_image_gen_configs, ) @@ -288,3 +291,90 @@ def test_generate_image_gen_configs_assigns_image_id_offset(): cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) assert all(c["id"] < -20_000 + 1 for c in cfgs) assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. + { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py index 00706f43c..1c74aa928 100644 --- a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -25,7 +25,7 @@ def _or_cfg( ) -> dict: return { "id": cid, - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_name, "billing_tier": tier, "auto_pin_tier": "B" if tier == "premium" else "C", @@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch): """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" yaml_cfg = { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "billing_tier": "premium", "auto_pin_tier": "A", @@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" yaml_cfg: dict[str, Any] = { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index e1437ca24..e97250ff2 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): [ { "id": 1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.4", "litellm_params": { "base_model": "gpt-5.4", @@ -243,6 +243,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): keys = spy.all_keys assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys assert "azure/gpt-5.4" in keys payload = spy.calls[0] @@ -270,7 +271,7 @@ def test_no_override_means_no_registration(monkeypatch): [ { "id": 1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-4o", "litellm_params": {"base_model": "gpt-4o"}, } @@ -301,7 +302,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -348,12 +349,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", }, { "id": 2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "custom-deployment", "litellm_params": { "base_model": "custom-deployment", @@ -368,3 +369,79 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): # The good config still registered. assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py index d20af14ae..aac88977f 100644 --- a/surfsense_backend/tests/unit/services/test_provider_capabilities.py +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit def test_or_modalities_with_image_returns_true(): assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="openai/gpt-4o", openrouter_input_modalities=["text", "image"], ) @@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true(): def test_or_modalities_text_only_returns_false(): assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="deepseek/deepseek-v3.2-exp", openrouter_input_modalities=["text"], ) @@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false(): to LiteLLM.""" assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="weird/empty-modalities", openrouter_input_modalities=[], ) @@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm(): to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" assert ( derive_supports_image_input( - provider="openai", + provider="OPENAI", model_name="gpt-4o", openrouter_input_modalities=None, ) @@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm(): def test_litellm_known_vision_model_returns_true(): assert ( derive_supports_image_input( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is True @@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name(): doesn't know) would shadow the real capability.""" assert ( derive_supports_image_input( - provider="azure", + provider="AZURE_OPENAI", model_name="my-azure-deployment-id", base_model="gpt-4o", ) @@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows(): """Default-allow on unknown — the safety net is the actual block.""" assert ( derive_supports_image_input( - provider="custom", + provider="CUSTOM", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false(): # Sanity: confirm the helper's negative path. We use a small model # known not to support vision per the map. result = derive_supports_image_input( - provider="openai", + provider="DEEPSEEK", model_name="deepseek-chat", ) # We accept either False (LiteLLM said explicit no) or True @@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false(): def test_is_known_text_only_returns_false_for_vision_model(): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model(): fixing.""" assert ( is_known_text_only_chat_model( - provider="custom", + provider="CUSTOM", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is True @@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is False @@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is False diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index cb3f7523a..6fbc8fd62 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -1,4 +1,4 @@ -"""Unit tests for the Auto quality scoring module.""" +"""Unit tests for the Auto (Fastest) quality scoring module.""" from __future__ import annotations @@ -228,7 +228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider(): def test_static_score_yaml_includes_operator_bonus(): cfg = { - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } @@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus(): def test_static_score_yaml_unknown_provider_still_carries_bonus(): cfg = { - "litellm_provider": "some_new_provider", + "provider": "SOME_NEW_PROVIDER", "model_name": "weird-model", } score = static_score_yaml(cfg) @@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus(): def test_static_score_yaml_clamped_0_to_100(): cfg = { - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py index 0f5dd531f..9e35b6f9c 100644 --- a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -105,7 +105,8 @@ async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): async def _denying_billable_call(**_kwargs): raise QuotaInsufficientError( usage_type="vision_extraction", - balance_micros=5_000_000, + used_micros=5_000_000, + limit_micros=5_000_000, remaining_micros=0, ) yield # unreachable but required for asynccontextmanager type diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py index 9eeb55a4d..63681828d 100644 --- a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -112,77 +112,6 @@ def test_per_message_summary_groups_cost_by_model(): assert summary["gpt-4o-mini"]["cost_micros"] == 200 -def test_add_reconciles_metadata_when_litellm_strips_provider_prefix(): - """Regression: LiteLLM's ``get_llm_provider`` strips the provider prefix we - add in ``to_litellm`` (``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because - ``azure`` is in ``litellm.provider_list``), so the success callback reports - the bare model. Metadata registered under the *prefixed* string must still - attach to the call so the per-model breakdown carries provider/display_name - — otherwise the UI falls back to a bare-name collision and mis-attributes an - Azure turn to an OpenRouter model (e.g. shows "OpenAI: GPT-5.2 Chat"). - """ - from app.services.token_tracking_service import TurnTokenAccumulator - - acc = TurnTokenAccumulator() - acc.register_model_metadata( - model="azure/gpt-5.2-chat", - model_ref="global:-1", - model_id="gpt-5.2-chat", - display_name="Azure GPT 5.2", - provider="azure", - ) - # LiteLLM callback fires with the prefix-stripped model name. - acc.add( - model="gpt-5.2-chat", - prompt_tokens=100, - completion_tokens=50, - total_tokens=150, - cost_micros=4_000, - ) - - summary = acc.per_message_summary() - entry = summary["gpt-5.2-chat"] - assert entry["provider"] == "azure" - assert entry["display_name"] == "Azure GPT 5.2" - assert entry["model_id"] == "gpt-5.2-chat" - assert entry["model_ref"] == "global:-1" - - -def test_add_prefers_exact_metadata_over_bare_alias(): - """When the callback model matches a registered key exactly, the exact - metadata wins even if another model shares the same bare name — so a turn - that legitimately used two same-named deployments stays correctly - attributed.""" - from app.services.token_tracking_service import TurnTokenAccumulator - - acc = TurnTokenAccumulator() - acc.register_model_metadata( - model="azure/gpt-5.2-chat", - model_ref="global:-1", - model_id="gpt-5.2-chat", - display_name="Azure GPT 5.2", - provider="azure", - ) - acc.register_model_metadata( - model="openai/gpt-5.2-chat", - model_ref="db:7", - model_id="gpt-5.2-chat", - display_name="OpenAI GPT 5.2", - provider="openai", - ) - acc.add( - model="openai/gpt-5.2-chat", - prompt_tokens=10, - completion_tokens=5, - total_tokens=15, - cost_micros=100, - ) - - entry = acc.per_message_summary()["openai/gpt-5.2-chat"] - assert entry["provider"] == "openai" - assert entry["display_name"] == "OpenAI GPT 5.2" - - def test_serialized_calls_includes_cost_micros(): """``serialized_calls`` is what flows into the SSE ``call_details`` payload; cost_micros must be present on each entry so the FE message-info @@ -202,10 +131,6 @@ def test_serialized_calls_includes_cost_micros(): assert serialized == [ { "model": "m", - "model_ref": None, - "model_id": None, - "display_name": None, - "provider": None, "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2, diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..5e3aa6eda --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": "openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py deleted file mode 100644 index 98ffaa282..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import pytest - -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception -from app.tasks.chat.streaming.errors.classifier import classify_stream_exception - -pytestmark = pytest.mark.unit - - -def _exception_named(name: str, message: str) -> Exception: - return type(name, (Exception,), {})(message) - - -def test_adapter_classifies_authentication_error_by_class_name() -> None: - exc = _exception_named("AuthenticationError", "provider rejected credentials") - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.AUTH_FAILED - assert adapted.retryable is False - assert adapted.user_message == "LLM authentication failed. Check your API key." - - -def test_adapter_classifies_embedded_provider_401_payload() -> None: - exc = RuntimeError( - 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' - ) - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.AUTH_FAILED - assert adapted.provider_status_code == 401 - - -def test_adapter_preserves_rate_limit_classification() -> None: - exc = RuntimeError('{"error":{"message":"Slow down","code":429}}') - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.RATE_LIMITED - assert adapted.retryable is True - - -def test_stream_classifier_maps_model_auth_to_stable_code() -> None: - exc = RuntimeError( - 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' - ) - - kind, code, severity, expected, message, extra = classify_stream_exception( - exc, - flow_label="chat", - ) - - assert kind == "model_auth_failed" - assert code == "MODEL_AUTH_FAILED" - assert severity == "warn" - assert expected is True - assert "API key" in message - assert extra == { - "provider_error_category": "auth_failed", - "provider_status_code": 401, - } - - -def test_stream_classifier_keeps_unknown_errors_generic() -> None: - exc = RuntimeError("database exploded") - - kind, code, severity, expected, message, extra = classify_stream_exception( - exc, - flow_label="chat", - ) - - assert kind == "server_error" - assert code == "SERVER_ERROR" - assert severity == "error" - assert expected is False - assert message == "Error during chat: database exploded" - assert extra is None diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py deleted file mode 100644 index cecf8be5d..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Contracts for chat LLM construction in streaming flows. - -``stream_new_chat`` / ``stream_resume_chat`` depend on LangChain receiving -token chunks from ``ChatLiteLLM``. ``langchain-litellm`` defaults -``streaming`` to ``False``, so the shared bundle loader must opt in -explicitly for both DB-backed and global model paths. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any - -import pytest - -import app.tasks.chat.streaming.flows.shared.llm_bundle as llm_bundle - -pytestmark = pytest.mark.unit - - -class _CapturedChatLiteLLM: - calls: list[dict[str, Any]] = [] - - def __init__(self, **kwargs: Any) -> None: - self.kwargs = kwargs - self.__class__.calls.append(kwargs) - - -@pytest.fixture(autouse=True) -def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch): - """Keep these tests focused on the LLM constructor contract.""" - - _CapturedChatLiteLLM.calls = [] - - async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace: - return SimpleNamespace(id=42, user_id="user-1") - - monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space) - monkeypatch.setattr(llm_bundle, "SanitizedChatLiteLLM", _CapturedChatLiteLLM) - monkeypatch.setattr(llm_bundle, "register_model_usage_metadata", lambda **_kw: None) - monkeypatch.setattr( - llm_bundle, - "has_capability", - lambda _model, capability: capability in {"chat", "vision"}, - ) - - return None - - -async def test_load_llm_bundle_enables_streaming_for_db_models( - monkeypatch: pytest.MonkeyPatch, -) -> None: - connection = SimpleNamespace( - provider="openai", - api_key="sk-test", - base_url=None, - extra={"litellm_params": {"temperature": 0.1}}, - ) - model = SimpleNamespace( - id=7, - model_id="gpt-4o-mini", - display_name="GPT 4o Mini", - connection=connection, - ) - - async def _fake_db_model(_session: Any, *, model_id: int, search_space: Any) -> Any: - assert model_id == 7 - assert search_space.id == 42 - return model - - monkeypatch.setattr(llm_bundle, "_load_db_model", _fake_db_model) - monkeypatch.setattr( - llm_bundle, - "to_litellm", - lambda _conn, _model_id: ( - "openai/gpt-4o-mini", - {"api_key": "sk-test", "temperature": 0.1}, - ), - ) - - llm, agent_config, error = await llm_bundle.load_llm_bundle( - object(), - config_id=7, - search_space_id=42, - ) - - assert error is None - assert llm is not None - assert agent_config is not None - assert _CapturedChatLiteLLM.calls == [ - { - "model": "openai/gpt-4o-mini", - "api_key": "sk-test", - "temperature": 0.1, - "streaming": True, - } - ] - - -async def test_load_llm_bundle_enables_streaming_for_global_models( - monkeypatch: pytest.MonkeyPatch, -) -> None: - global_model = { - "id": -11, - "connection_id": -101, - "model_id": "claude-sonnet-4-5", - "display_name": "Claude Sonnet", - "billing_tier": "premium", - } - global_connection = { - "id": -101, - "provider": "anthropic", - "api_key": "sk-ant-test", - "base_url": None, - "extra": {"litellm_params": {"temperature": 0.2}}, - } - monkeypatch.setattr( - llm_bundle.config, - "GLOBAL_MODELS", - [global_model], - raising=False, - ) - monkeypatch.setattr( - llm_bundle.config, - "GLOBAL_CONNECTIONS", - [global_connection], - raising=False, - ) - monkeypatch.setattr( - llm_bundle, - "to_litellm", - lambda _conn, _model_id: ( - "anthropic/claude-sonnet-4-5", - {"api_key": "sk-ant-test", "temperature": 0.2}, - ), - ) - - llm, agent_config, error = await llm_bundle.load_llm_bundle( - object(), - config_id=-11, - search_space_id=42, - ) - - assert error is None - assert llm is not None - assert agent_config is not None - assert _CapturedChatLiteLLM.calls == [ - { - "model": "anthropic/claude-sonnet-4-5", - "api_key": "sk-ant-test", - "temperature": 0.2, - "streaming": True, - } - ] diff --git a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py deleted file mode 100644 index 5b2e4fdca..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Unit tests for provider-safe LLM history normalization.""" - -from __future__ import annotations - -import pytest - -from app.tasks.chat.llm_history_normalizer import ( - assistant_content_to_llm_text, - user_content_to_llm_content, -) - -pytestmark = pytest.mark.unit - - -def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None: - content = [ - {"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]}, - {"type": "text", "text": "visible answer"}, - ] - - assert assistant_content_to_llm_text(content) == "visible answer" - - -def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None: - content = [ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "final answer"}, - ] - - assert assistant_content_to_llm_text(content) == "final answer" - - -def test_unknown_assistant_blocks_are_dropped() -> None: - content = [ - {"type": "redacted_thinking", "data": "hidden"}, - {"type": "tool_use", "name": "search"}, - {"type": "text", "text": "kept"}, - ] - - assert assistant_content_to_llm_text(content) == "kept" - - -def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None: - content = [ - {"type": "text", "text": "look"}, - {"type": "image", "image": "data:image/png;base64,abc"}, - ] - - assert user_content_to_llm_content(content, allow_images=True) == [ - {"type": "text", "text": "look"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, - ] - - -def test_user_images_can_be_dropped_for_text_only_history() -> None: - content = [ - {"type": "text", "text": "look"}, - {"type": "image", "image": "data:image/png;base64,abc"}, - ] - - assert user_content_to_llm_content(content, allow_images=False) == "look" diff --git a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py deleted file mode 100644 index 91ca01d95..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Unit tests for final assistant message part normalization.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage - -from app.tasks.chat.message_parts_normalizer import ( - final_assistant_parts_from_messages, - merge_streamed_and_final_parts, - normalize_ai_message_to_parts, -) - -pytestmark = pytest.mark.unit - - -def test_string_ai_message_content_becomes_text_part() -> None: - assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [ - {"type": "text", "text": "hello"} - ] - - -def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None: - message = AIMessage( - content=[ - {"type": "thinking", "thinking": "hidden reasoning"}, - {"type": "text", "text": "Yo bro! What's up?"}, - ], - additional_kwargs={"reasoning_content": "hidden reasoning"}, - ) - - assert normalize_ai_message_to_parts(message) == [ - {"type": "text", "text": "Yo bro! What's up?"} - ] - - -def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None: - messages = [ - HumanMessage(content="ask"), - AIMessage(content="draft"), - ToolMessage(content="tool output", tool_call_id="tc-1"), - AIMessage(content=[{"type": "text", "text": "final answer"}]), - ToolMessage(content="trailing tool noise", tool_call_id="tc-2"), - ] - - assert final_assistant_parts_from_messages(messages) == [ - {"type": "text", "text": "final answer"} - ] - - -def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None: - streamed = [ - { - "type": "data-thinking-steps", - "data": [{"id": "thinking-1", "status": "completed"}], - } - ] - final = [{"type": "text", "text": "visible answer"}] - - assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final] - - -def test_merge_does_not_duplicate_when_stream_already_has_text() -> None: - streamed = [{"type": "text", "text": "streamed answer"}] - final = [{"type": "text", "text": "final answer"}] - - assert merge_streamed_and_final_parts(streamed, final) == streamed diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py index 2342dd8da..a5bb3f58a 100644 --- a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -239,18 +239,17 @@ def test_video_presentation_task_uses_runner_helper() -> None: ) -def test_podcast_tasks_use_runner_helper() -> None: - """Symmetric assertion for the podcast tasks — same root cause, same +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same fix, same regression risk. """ import inspect - from app.podcasts.tasks import draft, render + from app.tasks.celery_tasks import podcast_tasks - for module in (draft, render): - src = inspect.getsource(module) - assert "run_async_celery_task" in src - assert "asyncio.new_event_loop" not in src + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src def test_runner_runs_shutdown_asyncgens_before_close() -> None: diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..699297df1 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,388 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. + + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py index e82957acd..792d059b0 100644 --- a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o(): it text-only.""" assert ( is_known_text_only_chat_model( - provider="azure", + provider="AZURE_OPENAI", model_name="my-azure-deployment", base_model="gpt-4o", ) @@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model(): LiteLLM doesn't know about must flow through to the provider.""" assert ( is_known_text_only_chat_model( - provider="custom", + provider="CUSTOM", custom_provider="brand_new_proxy", model_name="brand-new-model-x9", ) @@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="text-only-stub", ) is True @@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="vision-stub", ) is False @@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="missing-key-stub", ) is False diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py index 7183024ed..423b64ddb 100644 --- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -98,7 +98,8 @@ async def _denying_billable_call(**kwargs): _CALL_LOG.append(kwargs) raise QuotaInsufficientError( usage_type=kwargs.get("usage_type", "?"), - balance_micros=5_000_000, + used_micros=5_000_000, + limit_micros=5_000_000, remaining_micros=0, ) yield SimpleNamespace() # pragma: no cover diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 8c540b41c..a927a928d 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -15,12 +15,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -34,12 +28,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -53,12 +41,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -72,12 +54,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -91,12 +67,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -110,12 +80,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -129,12 +93,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -148,12 +106,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -167,12 +119,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -186,12 +132,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -209,14 +149,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -230,12 +162,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -249,12 +175,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -268,12 +188,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -287,12 +201,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -306,12 +214,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -325,12 +227,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -344,12 +240,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -363,12 +253,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -382,12 +266,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -401,12 +279,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -424,14 +296,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -445,12 +309,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -464,12 +322,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -483,12 +335,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -502,12 +348,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -521,12 +361,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -540,12 +374,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -559,12 +387,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -578,12 +400,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -597,12 +413,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -616,12 +426,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -639,14 +443,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -660,12 +456,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -691,18 +481,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -716,12 +494,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -735,12 +507,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -766,18 +532,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -803,18 +557,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -828,12 +570,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", ] conflicts = [[ @@ -3331,6 +3067,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661 }, ] +[[package]] +name = "flower" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "celery" }, + { name = "humanize" }, + { name = "prometheus-client" }, + { name = "pytz" }, + { name = "tornado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/a1/357f1b5d8946deafdcfdd604f51baae9de10aafa2908d0b7322597155f92/flower-2.0.1.tar.gz", hash = "sha256:5ab717b979530770c16afb48b50d2a98d23c3e9fe39851dcf6bc4d01845a02a0", size = 3220408 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/ff/ee2f67c0ff146ec98b5df1df637b2bc2d17beeb05df9f427a67bd7a7d79c/flower-2.0.1-py2.py3-none-any.whl", hash = "sha256:9db2c621eeefbc844c8dd88be64aef61e84e2deb29b271e02ab2b5b9f01068e2", size = 383553 }, +] + [[package]] name = "flupy" version = "1.2.3" @@ -4029,6 +3781,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/af/48ac8483240de756d2438c380746e7130d1c6f75802ef22f3c6d49982787/huggingface_hub-0.36.2-py3-none-any.whl", hash = "sha256:48f0c8eac16145dfce371e9d2d7772854a4f591bcb56c9cf548accf531d54270", size = 566395 }, ] +[[package]] +name = "humanize" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/66/a3921783d54be8a6870ac4ccffcd15c4dc0dd7fcce51c6d63b8c63935276/humanize-4.15.0.tar.gz", hash = "sha256:1dd098483eb1c7ee8e32eb2e99ad1910baefa4b75c3aff3a82f4d78688993b10", size = 83599 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl", hash = "sha256:b1186eb9f5a9749cd9cb8565aee77919dd7c8d076161cf44d70e59e3301e1769", size = 132203 }, +] + [[package]] name = "hyperframe" version = "6.1.0" @@ -7546,6 +7307,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/59/123aee44a039b212cfb8d90be1adf06496a99b313ee1683aadf90b3d9799/progress-1.6.1-py3-none-any.whl", hash = "sha256:5239f22f305c12fdc8ce6e0e47f70f21622a935e16eafc4535617112e7c7ea0b", size = 9761 }, ] +[[package]] +name = "prometheus-client" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/58/a794d23feb6b00fc0c72787d7e87d872a6730dd9ed7c7b3e954637d8f280/prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9", size = 85616 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057 }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -9712,7 +9482,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.29" +version = "0.0.27" source = { editable = "." } dependencies = [ { name = "alembic" }, @@ -9737,6 +9507,7 @@ dependencies = [ { name = "fastapi-users", extra = ["oauth", "sqlalchemy"] }, { name = "faster-whisper" }, { name = "firecrawl-py" }, + { name = "flower" }, { name = "fractional-indexing" }, { name = "github3-py" }, { name = "gitingest" }, @@ -9855,6 +9626,7 @@ requires-dist = [ { name = "fastapi-users", extras = ["oauth", "sqlalchemy"], specifier = ">=15.0.3" }, { name = "faster-whisper", specifier = ">=1.1.0" }, { name = "firecrawl-py", specifier = ">=4.9.0" }, + { name = "flower", specifier = ">=2.0.1" }, { name = "fractional-indexing", specifier = ">=0.1.3" }, { name = "github3-py", specifier = "==4.0.1" }, { name = "gitingest", specifier = ">=0.3.1" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 4d888acdb..959e0b395 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.29", + "version": "0.0.27", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index 5b663de00..433e33315 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,7 +1,7 @@ { "name": "surfsense-desktop", "productName": "SurfSense", - "version": "0.0.29", + "version": "0.0.27", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_desktop/scripts/build-electron.mjs b/surfsense_desktop/scripts/build-electron.mjs index cc2083fe4..75a3cdf61 100644 --- a/surfsense_desktop/scripts/build-electron.mjs +++ b/surfsense_desktop/scripts/build-electron.mjs @@ -108,11 +108,8 @@ async function buildElectron() { sourcemap: true, minify: false, define: { - 'process.env.HOSTED_BACKEND_URL': JSON.stringify( - process.env.HOSTED_BACKEND_URL || desktopEnv.HOSTED_BACKEND_URL || '' - ), 'process.env.HOSTED_FRONTEND_URL': JSON.stringify( - process.env.HOSTED_FRONTEND_URL || desktopEnv.HOSTED_FRONTEND_URL || 'https://surfsense.com' + process.env.HOSTED_FRONTEND_URL || desktopEnv.HOSTED_FRONTEND_URL || 'https://surfsense.net' ), 'process.env.POSTHOG_KEY': JSON.stringify( process.env.POSTHOG_KEY || desktopEnv.POSTHOG_KEY || '' diff --git a/surfsense_desktop/src/modules/server.ts b/surfsense_desktop/src/modules/server.ts index d7274ad9c..fc2fa05c3 100644 --- a/surfsense_desktop/src/modules/server.ts +++ b/surfsense_desktop/src/modules/server.ts @@ -43,13 +43,11 @@ export async function startNextServer(): Promise<void> { const standalonePath = getStandalonePath(); const serverScript = path.join(standalonePath, 'server.js'); - const backendInternalUrl = process.env.SURFSENSE_BACKEND_INTERNAL_URL || process.env.HOSTED_BACKEND_URL; const child = utilityProcess.fork(serverScript, [], { cwd: standalonePath, env: { ...process.env, - ...(backendInternalUrl ? { SURFSENSE_BACKEND_INTERNAL_URL: backendInternalUrl } : {}), PORT: String(serverPort), // Loopback bind: avoids 0.0.0.0 leaking into request.url and redirect origins. HOSTNAME: SERVER_HOST, diff --git a/surfsense_evals/README.md b/surfsense_evals/README.md index e6fc52ca1..c755c4de6 100644 --- a/surfsense_evals/README.md +++ b/surfsense_evals/README.md @@ -77,7 +77,7 @@ The walkthrough above is `--scenario head-to-head` (default): both arms answer w | `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? | | `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?| -In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_model_id` (auto-picked from the strongest registered global OpenRouter vision-capable model — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm <slug>`). What changes is which slug the *answering* models hit per arm. +In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm <slug>`). What changes is which slug the *answering* models hit per arm. ### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`) @@ -118,7 +118,7 @@ python -m surfsense_evals report --suite medical Notes: - `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model <vision slug>`. -- `--vision-llm <slug>` is optional; if omitted the harness queries `GET /api/v1/model-connections/global` and auto-picks the strongest registered vision-capable model. Pass `--no-vision-llm-setup` if you want to keep whatever vision model is already attached to the SearchSpace. +- `--vision-llm <slug>` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace. - The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration. - All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set. diff --git a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json index b6c59e2bc..a4687f64a 100644 --- a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json +++ b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json @@ -9,7 +9,7 @@ "llamacloud_premium_lc", "surfsense_agentic" ], - "chat_model_id": -5138454, + "agent_llm_id": -5138454, "concurrency": 2, "llm_model": "anthropic/claude-sonnet-4.5", "n_pdfs": 30, diff --git a/surfsense_evals/src/surfsense_evals/core/cli.py b/surfsense_evals/src/surfsense_evals/core/cli.py index 17979fba0..3d4d0fd24 100644 --- a/surfsense_evals/src/surfsense_evals/core/cli.py +++ b/surfsense_evals/src/surfsense_evals/core/cli.py @@ -2,7 +2,7 @@ Subcommands: -* ``setup --suite <name> --provider-model <slug> [--chat-model-id <int>]`` +* ``setup --suite <name> --provider-model <slug> [--agent-llm-id <int>]`` * ``teardown --suite <name>`` * ``models list [--provider openrouter] [--grep <s>]`` * ``suites list`` @@ -18,7 +18,7 @@ publish its own flags. Design choices worth flagging: -* ``setup`` rejects ``chat_model_id == 0`` (Auto / LiteLLM router) so +* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so per-question accuracy is reproducible. * ``setup`` validates that the picked LLM config has ``provider == "OPENROUTER"`` and ``model_name == --provider-model`` @@ -59,6 +59,7 @@ if sys.platform == "win32": from . import registry from .auth import CredentialError, acquire_token, client_with_auth from .clients import SearchSpaceClient +from .clients.search_space import LlmPreferences from .config import ( DEFAULT_SCENARIO, SCENARIOS, @@ -110,30 +111,23 @@ class LlmConfigEntry: def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry: return cls( id=int(payload["id"]), - name=str(payload.get("display_name") or payload.get("name") or ""), + name=str(payload.get("name", "")), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_id") or payload.get("model_name") or ""), + model_name=str(payload.get("model_name", "")), raw=payload, ) async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]: response = await http.get( - f"{base}/api/v1/model-connections/global", + f"{base}/api/v1/global-new-llm-configs", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): - raise RuntimeError(f"Unexpected /model-connections/global payload: {payload!r}") - entries: list[LlmConfigEntry] = [] - for connection in payload: - provider = connection.get("provider", "") - for model in connection.get("models") or []: - if not model.get("enabled", True) or not model.get("supports_chat"): - continue - entries.append(LlmConfigEntry.from_payload({**model, "provider": provider})) - return entries + raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}") + return [LlmConfigEntry.from_payload(item) for item in payload] def _resolve_openrouter_id( @@ -149,8 +143,8 @@ def _resolve_openrouter_id( * If ``explicit_id`` is given: return it directly. The caller is then expected to GET-validate that the row's ``provider == "OPENROUTER"`` and ``model_name`` matches the slug. - That branch supports positive BYOK model rows whose slugs may overlap - with global OpenRouter virtuals. + That branch supports positive BYOK ``NewLLMConfig`` rows whose + slugs may overlap with global OpenRouter virtuals. * Otherwise: filter to ``provider == "OPENROUTER"`` and ``model_name == provider_model``. Expect exactly one match — raise with a friendly message otherwise. @@ -179,7 +173,7 @@ def _resolve_openrouter_id( listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches) raise RuntimeError( f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n" - "Pass --chat-model-id <id> to disambiguate." + "Pass --agent-llm-id <id> to disambiguate." ) return matches[0].id @@ -192,7 +186,7 @@ def _resolve_openrouter_id( async def _cmd_setup(args: argparse.Namespace) -> int: suite = args.suite provider_model: str = args.provider_model - explicit_id: int | None = args.chat_model_id + explicit_id: int | None = args.agent_llm_id scenario: str = args.scenario vision_llm_slug: str | None = args.vision_llm native_arm_model: str | None = args.native_arm_model @@ -200,7 +194,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: if explicit_id == 0: console.print( - "[red]chat_model_id == 0 (Auto / LiteLLM router) is not allowed — " + "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — " "results would not be reproducible.[/red]" ) return 2 @@ -248,7 +242,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: candidates = await _list_global_llm_configs(http, config.surfsense_api_base) try: - chat_model_id = _resolve_openrouter_id( + agent_llm_id = _resolve_openrouter_id( candidates, provider_model, explicit_id=explicit_id ) except RuntimeError as exc: @@ -294,7 +288,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: vision_provider_model: str | None = None if not skip_vision_setup and (vision_required or vision_llm_slug is not None): try: - vision_candidates = await ss_client.list_global_vision_models() + vision_candidates = await ss_client.list_global_vision_llm_configs() resolved = resolve_vision_llm( vision_candidates, explicit_slug=vision_llm_slug ) @@ -308,34 +302,37 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"(id={vision_config_id}, selected_via={resolved.selected_via})." ) - role_kwargs: dict[str, Any] = {"chat_model_id": chat_model_id} + pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id} if vision_config_id is not None: - role_kwargs["vision_model_id"] = vision_config_id + pref_kwargs["vision_llm_config_id"] = vision_config_id - await ss_client.set_model_roles(search_space_id, **role_kwargs) - roles = await ss_client.get_model_roles(search_space_id) - if roles.chat_model_id != chat_model_id: + await ss_client.set_llm_preferences(search_space_id, **pref_kwargs) + prefs = await ss_client.get_llm_preferences(search_space_id) + if not _validate_pin(prefs, provider_model): + agent = prefs.agent_llm or {} console.print( f"[red]LLM pin validation FAILED.[/red] After PUT, " - f"chat_model_id={roles.chat_model_id!r}; expected {chat_model_id!r}." + f"agent_llm.provider={agent.get('provider')!r}, " + f"model_name={agent.get('model_name')!r}; expected " + f"provider=OPENROUTER, model_name={provider_model!r}." ) return 2 - if vision_config_id is not None and roles.vision_model_id != vision_config_id: + if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id: console.print( f"[red]Vision LLM pin validation FAILED.[/red] After PUT, " - f"vision_model_id={roles.vision_model_id!r}; " + f"vision_llm_config_id={prefs.vision_llm_config_id!r}; " f"expected {vision_config_id!r}." ) return 2 suite_state = SuiteState( search_space_id=search_space_id, - chat_model_id=chat_model_id, + agent_llm_id=agent_llm_id, provider_model=provider_model, created_at=utc_iso_timestamp(), ingestion_maps=existing.ingestion_maps if existing else {}, scenario=scenario, - vision_model_id=vision_config_id, + vision_llm_config_id=vision_config_id, vision_provider_model=vision_provider_model, native_arm_model=native_arm_model, ) @@ -345,7 +342,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"suite={suite!r}", f"scenario={scenario!r}", f"search_space_id={suite_state.search_space_id}", - f"chat_model_id={suite_state.chat_model_id}", + f"agent_llm_id={suite_state.agent_llm_id}", f"provider_model={suite_state.provider_model!r}", ] if suite_state.vision_provider_model: @@ -356,6 +353,14 @@ async def _cmd_setup(args: argparse.Namespace) -> int: return 0 +def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool: + agent = prefs.agent_llm or {} + return ( + str(agent.get("provider", "")).upper() == "OPENROUTER" + and str(agent.get("model_name", "")) == provider_model + ) + + async def _cmd_teardown(args: argparse.Namespace) -> int: suite = args.suite config = load_config() @@ -649,10 +654,10 @@ def _build_parser() -> argparse.ArgumentParser: ), ) p_setup.add_argument( - "--chat-model-id", + "--agent-llm-id", type=int, default=None, - help="Optional explicit model id override.", + help="Optional override for BYOK NewLLMConfig rows.", ) p_setup.add_argument( "--scenario", diff --git a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py index efd4a571d..e2d37694d 100644 --- a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py +++ b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py @@ -1,16 +1,17 @@ -"""Client for ``/api/v1/searchspaces`` and model-role endpoints. +"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``. Verified against: * ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create) * ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id) * ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete) -* ``surfsense_backend/app/routes/model_connections_routes.py`` (GET/PUT model roles) +* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences) * ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body) +* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs) Note the inconsistent pluralisation in the backend: ``/searchspaces`` -(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for model-role -sub-resources. Both are mirrored verbatim here. +(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the +``llm-preferences`` sub-resource. Both are mirrored verbatim here. """ from __future__ import annotations @@ -45,8 +46,13 @@ class SearchSpaceRow: @dataclass -class VisionModelEntry: - """Subset of one GLOBAL model-connection model with image input support.""" +class VisionLlmConfigEntry: + """Subset of one ``GET /global-vision-llm-configs`` row. + + The backend returns negative ids for global / OpenRouter-derived + vision configs and positive ids for per-user BYOK rows. Either is + accepted by ``set_llm_preferences(vision_llm_config_id=...)``. + """ id: int name: str @@ -56,38 +62,45 @@ class VisionModelEntry: raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> VisionModelEntry: + def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry: return cls( id=int(payload.get("id", 0)), - name=str(payload.get("display_name") or payload.get("model_id") or ""), + name=str(payload.get("name", "")), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_id", "")), - is_auto_mode=False, + model_name=str(payload.get("model_name", "")), + is_auto_mode=bool(payload.get("is_auto_mode", False)), raw=payload, ) @dataclass -class ModelRoles: - """Model role ids for a search space.""" +class LlmPreferences: + """Resolved LLM preferences with the embedded full config row. - chat_model_id: int | None - image_gen_model_id: int | None - vision_model_id: int | None + Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle + command can introspect ``provider`` / ``model_name`` to validate the + OpenRouter pin. + """ + + agent_llm_id: int | None + image_generation_config_id: int | None + vision_llm_config_id: int | None + agent_llm: dict[str, Any] | None raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> ModelRoles: + def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences: return cls( - chat_model_id=payload.get("chat_model_id"), - image_gen_model_id=payload.get("image_gen_model_id"), - vision_model_id=payload.get("vision_model_id"), + agent_llm_id=payload.get("agent_llm_id"), + image_generation_config_id=payload.get("image_generation_config_id"), + vision_llm_config_id=payload.get("vision_llm_config_id"), + agent_llm=payload.get("agent_llm"), raw=payload, ) class SearchSpaceClient: - """Thin wrapper around the SearchSpace + model role endpoints.""" + """Thin wrapper around the SearchSpace + LLM preferences endpoints.""" def __init__(self, http: httpx.AsyncClient, base_url: str) -> None: self._http = http @@ -126,67 +139,64 @@ class SearchSpaceClient: return response.raise_for_status() - async def get_model_roles(self, search_space_id: int) -> ModelRoles: + async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences: response = await self._http.get( - f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", + f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", headers={"Accept": "application/json"}, ) response.raise_for_status() - return ModelRoles.from_payload(response.json()) + return LlmPreferences.from_payload(response.json()) - async def set_model_roles( + async def set_llm_preferences( self, search_space_id: int, *, - chat_model_id: int | None = None, - image_gen_model_id: int | None = None, - vision_model_id: int | None = None, - ) -> ModelRoles: - """PUT a partial update to ``/search-spaces/{id}/model-roles``. + agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, + ) -> LlmPreferences: + """PUT a partial update to ``/search-spaces/{id}/llm-preferences``. Backend uses ``model_dump(exclude_unset=True)`` so omitted fields are left unchanged. """ body: dict[str, Any] = {} - if chat_model_id is not None: - body["chat_model_id"] = chat_model_id - if image_gen_model_id is not None: - body["image_gen_model_id"] = image_gen_model_id - if vision_model_id is not None: - body["vision_model_id"] = vision_model_id + if agent_llm_id is not None: + body["agent_llm_id"] = agent_llm_id + if image_generation_config_id is not None: + body["image_generation_config_id"] = image_generation_config_id + if vision_llm_config_id is not None: + body["vision_llm_config_id"] = vision_llm_config_id response = await self._http.put( - f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", + f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", json=body, headers={"Accept": "application/json"}, ) response.raise_for_status() - return ModelRoles.from_payload(response.json()) + return LlmPreferences.from_payload(response.json()) - async def list_global_vision_models(self) -> list[VisionModelEntry]: - """List registered GLOBAL models that can accept image input. + async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]: + """List the registered global vision LLM configs. - Used by ``setup`` to resolve ``--vision-llm <slug>`` or auto-pick a - reproducible ingest-time vision model. + Used by ``setup`` to (a) resolve an explicit ``--vision-llm <slug>`` + to a config id and (b) auto-pick the strongest registered vision + config when the operator doesn't pass one. The ``Auto (Fastest)`` + entry (``id=0``) is filtered out — accuracy must be reproducible. """ response = await self._http.get( - f"{self._base}/api/v1/model-connections/global", + f"{self._base}/api/v1/global-vision-llm-configs", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): raise RuntimeError( - f"Unexpected /model-connections/global payload: {payload!r}" + f"Unexpected /global-vision-llm-configs payload: {payload!r}" ) - entries: list[VisionModelEntry] = [] - for connection in payload: - provider = str(connection.get("provider", "")) - for model in connection.get("models") or []: - if not model.get("enabled", True) or not model.get("supports_image_input"): - continue - entries.append( - VisionModelEntry.from_payload({**model, "provider": provider}) - ) - return entries + return [ + VisionLlmConfigEntry.from_payload(item) + for item in payload + if not bool(item.get("is_auto_mode", False)) + ] diff --git a/surfsense_evals/src/surfsense_evals/core/config.py b/surfsense_evals/src/surfsense_evals/core/config.py index 9a5a71e89..164955914 100644 --- a/surfsense_evals/src/surfsense_evals/core/config.py +++ b/surfsense_evals/src/surfsense_evals/core/config.py @@ -147,35 +147,35 @@ class SuiteState: """Per-suite persisted state. ``provider_model`` is the slug pinned to the SearchSpace's - ``chat_model_id`` — what answers SurfSense queries (and what the native + ``agent_llm`` — what answers SurfSense queries (and what the native arm uses too, unless ``native_arm_model`` is set for cost-arbitrage). - ``vision_provider_model`` is the slug of the OpenRouter vision model - attached to the SearchSpace's ``vision_model_id`` — what + ``vision_provider_model`` is the slug of the OpenRouter vision LLM + config attached to the SearchSpace's ``vision_llm_config_id`` — what SurfSense uses to extract image content at ingest time when ``use_vision_llm=True``. ``None`` means no vision config was attached at setup (legacy or text-only suite). """ search_space_id: int - chat_model_id: int + agent_llm_id: int provider_model: str created_at: str ingestion_maps: dict[str, str] = field(default_factory=dict) scenario: str = DEFAULT_SCENARIO - vision_model_id: int | None = None + vision_llm_config_id: int | None = None vision_provider_model: str | None = None native_arm_model: str | None = None def to_dict(self) -> dict[str, Any]: return { "search_space_id": self.search_space_id, - "chat_model_id": self.chat_model_id, + "agent_llm_id": self.agent_llm_id, "provider_model": self.provider_model, "created_at": self.created_at, "ingestion_maps": dict(self.ingestion_maps), "scenario": self.scenario, - "vision_model_id": self.vision_model_id, + "vision_llm_config_id": self.vision_llm_config_id, "vision_provider_model": self.vision_provider_model, "native_arm_model": self.native_arm_model, } @@ -187,16 +187,15 @@ class SuiteState: scenario = str(payload.get("scenario") or DEFAULT_SCENARIO) if scenario not in SCENARIOS: scenario = DEFAULT_SCENARIO - raw_chat_id = payload.get("chat_model_id") - raw_vision_id = payload.get("vision_model_id") + raw_vision_id = payload.get("vision_llm_config_id") return cls( search_space_id=int(payload["search_space_id"]), - chat_model_id=int(raw_chat_id), + agent_llm_id=int(payload["agent_llm_id"]), provider_model=str(payload["provider_model"]), created_at=str(payload.get("created_at") or ""), ingestion_maps=dict(payload.get("ingestion_maps") or {}), scenario=scenario, - vision_model_id=int(raw_vision_id) if raw_vision_id is not None else None, + vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None, vision_provider_model=( str(payload["vision_provider_model"]) if payload.get("vision_provider_model") diff --git a/surfsense_evals/src/surfsense_evals/core/registry.py b/surfsense_evals/src/surfsense_evals/core/registry.py index 65f64c39a..cc8b725e0 100644 --- a/surfsense_evals/src/surfsense_evals/core/registry.py +++ b/surfsense_evals/src/surfsense_evals/core/registry.py @@ -53,8 +53,8 @@ class RunContext: return self.suite_state.search_space_id @property - def chat_model_id(self) -> int: - return self.suite_state.chat_model_id + def agent_llm_id(self) -> int: + return self.suite_state.agent_llm_id @property def provider_model(self) -> str: diff --git a/surfsense_evals/src/surfsense_evals/core/vision_llm.py b/surfsense_evals/src/surfsense_evals/core/vision_llm.py index 5d5e2c6d1..ae96f1285 100644 --- a/surfsense_evals/src/surfsense_evals/core/vision_llm.py +++ b/surfsense_evals/src/surfsense_evals/core/vision_llm.py @@ -3,8 +3,8 @@ Two responsibilities: 1. Resolve an explicit ``--vision-llm <slug>`` to a global OpenRouter - vision-capable model id that ``set_model_roles(vision_model_id=...)`` can - accept. + vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)`` + can accept. 2. Auto-pick the strongest registered vision config when the operator doesn't pass ``--vision-llm`` but the scenario / benchmark needs one. diff --git a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py index ac0651996..e1a830138 100644 --- a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py @@ -371,7 +371,7 @@ class MedXpertQAMMBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py index b7685766e..95a1e15eb 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py @@ -391,7 +391,7 @@ class MMLongBenchDocBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py index 2c4a0ffe4..e71dffa65 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py @@ -554,7 +554,7 @@ class ParserCompareBenchmark: "scenario": ctx.scenario, "provider_model": ctx.provider_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "preprocess_tariff": { "basic_per_1k_pages": 1.0, "premium_per_1k_pages": 10.0, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py index 654c261a2..8b759e0d8 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py @@ -467,7 +467,7 @@ class CragBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, "per_page_char_cap": per_page_char_cap, "max_output_tokens": max_output_tokens, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py index 450c7ddd6..9c0e16b00 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py @@ -372,7 +372,7 @@ class FramesBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, "bare_arm_label": "bare_llm", }, diff --git a/surfsense_evals/tests/core/test_clients.py b/surfsense_evals/tests/core/test_clients.py index aa98f0ad4..611408703 100644 --- a/surfsense_evals/tests/core/test_clients.py +++ b/surfsense_evals/tests/core/test_clients.py @@ -63,22 +63,29 @@ async def test_delete_search_space_idempotent_on_404(respx_mock, http): @pytest.mark.asyncio @respx.mock(base_url=_BASE) -async def test_set_model_roles_partial_update(respx_mock, http): - route = respx_mock.put("/api/v1/search-spaces/42/model-roles").mock( +async def test_set_llm_preferences_partial_update(respx_mock, http): + route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock( return_value=httpx.Response( 200, json={ - "chat_model_id": -10042, - "image_gen_model_id": None, - "vision_model_id": None, + "agent_llm_id": -10042, + "agent_llm_id": None, + "image_generation_config_id": None, + "vision_llm_config_id": None, + "agent_llm": { + "id": -10042, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-sonnet-4.5", + }, }, ) ) client = SearchSpaceClient(http, _BASE) - roles = await client.set_model_roles(42, chat_model_id=-10042) - assert roles.chat_model_id == -10042 + prefs = await client.set_llm_preferences(42, agent_llm_id=-10042) + assert prefs.agent_llm_id == -10042 + assert prefs.agent_llm["provider"] == "OPENROUTER" sent_body = json.loads(route.calls[-1].request.content) - assert sent_body == {"chat_model_id": -10042} + assert sent_body == {"agent_llm_id": -10042} # --------------------------------------------------------------------------- diff --git a/surfsense_evals/tests/core/test_config.py b/surfsense_evals/tests/core/test_config.py index 6f9671c86..f7b8f7249 100644 --- a/surfsense_evals/tests/core/test_config.py +++ b/surfsense_evals/tests/core/test_config.py @@ -41,14 +41,14 @@ def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001 assert get_suite_state(config, "medical") is None state = SuiteState( search_space_id=1, - chat_model_id=-10042, + agent_llm_id=-10042, provider_model="anthropic/claude-sonnet-4.5", created_at="2026-05-11T20-30-00Z", ) set_suite_state(config, "medical", state) legal = SuiteState( search_space_id=2, - chat_model_id=-1, + agent_llm_id=-1, provider_model="openai/gpt-5", created_at="2026-05-11T21-00-00Z", ) @@ -84,19 +84,25 @@ def test_paths_are_per_suite(tmp_env): # noqa: ARG001 # --------------------------------------------------------------------------- -def test_minimal_state_defaults_to_head_to_head(): - """Missing scenario / vision / native fields default safely.""" +def test_legacy_state_back_compat_defaults_to_head_to_head(): + """state.json files written before scenarios shipped must still load. - payload = { + Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all + default to ``head-to-head`` / ``None`` so old setups keep working + after upgrade — the runner's behaviour exactly mirrors the legacy + one (both arms answer with ``provider_model``). + """ + + legacy = { "search_space_id": 7, - "chat_model_id": -123, + "agent_llm_id": -123, "provider_model": "anthropic/claude-sonnet-4.5", "created_at": "2026-05-11T20-30-00Z", "ingestion_maps": {}, } - state = SuiteState.from_dict(payload) + state = SuiteState.from_dict(legacy) assert state.scenario == DEFAULT_SCENARIO == "head-to-head" - assert state.vision_model_id is None + assert state.vision_llm_config_id is None assert state.vision_provider_model is None assert state.native_arm_model is None # The native arm should still answer with the same slug as SurfSense. @@ -112,7 +118,7 @@ def test_unknown_scenario_falls_back_to_default(): payload = { "search_space_id": 1, - "chat_model_id": -1, + "agent_llm_id": -1, "provider_model": "openai/gpt-5", "scenario": "unknown-scenario-name", } @@ -124,11 +130,11 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 config = load_config() state = SuiteState( search_space_id=42, - chat_model_id=-1, + agent_llm_id=-1, provider_model="openai/gpt-5.4-mini", created_at="2026-05-11T20-30-00Z", scenario="cost-arbitrage", - vision_model_id=-101, + vision_llm_config_id=-101, vision_provider_model="anthropic/claude-sonnet-4.5", native_arm_model="anthropic/claude-sonnet-4.5", ) @@ -136,7 +142,7 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 fetched = get_suite_state(config, "medical") assert fetched.scenario == "cost-arbitrage" - assert fetched.vision_model_id == -101 + assert fetched.vision_llm_config_id == -101 assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5" assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5" # Cost arbitrage's whole point: native arm slug != surfsense slug. diff --git a/surfsense_evals/tests/test_integration_smoke.py b/surfsense_evals/tests/test_integration_smoke.py index 1c89ae5ab..493c04c25 100644 --- a/surfsense_evals/tests/test_integration_smoke.py +++ b/surfsense_evals/tests/test_integration_smoke.py @@ -27,7 +27,7 @@ async def test_smoke_against_localhost(): pytest.skip("No credentials in environment; skipping integration smoke") bundle = await acquire_token(config) async with client_with_auth(config, bundle) as client: - response = await client.get(f"{config.surfsense_api_base}/api/v1/model-connections/global") + response = await client.get(f"{config.surfsense_api_base}/api/v1/global-new-llm-configs") try: response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/surfsense_web/.env.example b/surfsense_web/.env.example index 7d03cf498..5fb9d07d1 100644 --- a/surfsense_web/.env.example +++ b/surfsense_web/.env.example @@ -1,84 +1,30 @@ -# ───────────────────────────────────────────────────────────────────────────── -# Backend connectivity -# ───────────────────────────────────────────────────────────────────────────── +NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 -# Optional packaged-client override. Leave unset in Docker so browser requests -# use same-origin relative URLs behind Caddy. Set it for packaged clients -# (e.g. Electron) or local dev that talks to a separate backend origin. -# NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 +# Server-only. Internal backend URL used by Next.js server code. +FASTAPI_BACKEND_INTERNAL_URL=https://your-internal-backend.example.com -# Server-only. Internal backend URL used by Next.js server code (RSC / route -# handlers). Cannot be a relative URL. -SURFSENSE_BACKEND_INTERNAL_URL=http://backend:8000 +NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE +NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING +NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848 -# ───────────────────────────────────────────────────────────────────────────── -# Runtime configuration (read at runtime by the server, no rebuild needed) -# ───────────────────────────────────────────────────────────────────────────── -# Configure these plain variables for runtime behavior. They are read by server -# code when the app starts/serves requests, so changing them requires restarting -# the web process but not rebuilding the frontend bundle. -# -# Authentication method: LOCAL (email/password) or GOOGLE (OAuth). -AUTH_TYPE=LOCAL -# Document parsing backend: DOCLING, LLAMACLOUD, etc. -ETL_SERVICE=DOCLING -# Deployment mode: self-hosted or cloud. -DEPLOYMENT_MODE=self-hosted - -# ───────────────────────────────────────────────────────────────────────────── -# Database (Contact Form, optional) -# ───────────────────────────────────────────────────────────────────────────── +# Contact Form Vars (optional) DATABASE_URL=postgresql://postgres:[YOUR-PASSWORD]@db.sdsf.supabase.co:5432/postgres -# ───────────────────────────────────────────────────────────────────────────── -# PostHog analytics (optional, leave key empty to disable) -# ───────────────────────────────────────────────────────────────────────────── +# Deployment mode (optional) +NEXT_PUBLIC_DEPLOYMENT_MODE="self-hosted" or "cloud" + +# PostHog analytics (optional, leave empty to disable) NEXT_PUBLIC_POSTHOG_KEY= -NEXT_PUBLIC_POSTHOG_HOST=https://us.i.posthog.com -# ───────────────────────────────────────────────────────────────────────────── -# Zero cache (real-time sync). Leave unset in Docker to use the same-origin -# "/zero" endpoint behind Caddy. Set it for local dev or packaged clients. -# ───────────────────────────────────────────────────────────────────────────── -# NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848 - -# ───────────────────────────────────────────────────────────────────────────── # Cloudflare Turnstile CAPTCHA for anonymous chat abuse prevention # Get your site key from https://dash.cloudflare.com/ -> Turnstile -# ───────────────────────────────────────────────────────────────────────────── NEXT_PUBLIC_TURNSTILE_SITE_KEY= -# ───────────────────────────────────────────────────────────────────────────── # Google AdSense (optional, only enables ads on the /free hub page). # Publisher ID from your AdSense dashboard, e.g. ca-pub-XXXXXXXXXXXXXXXX. # Leave empty to disable ad rendering entirely. -# ───────────────────────────────────────────────────────────────────────────── NEXT_PUBLIC_GOOGLE_ADSENSE_CLIENT_ID= # Ad unit slot IDs from AdSense dashboard -> Ads -> By ad unit. # Leave empty to hide individual slots while keeping the script loaded. NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_IN_CONTENT= -NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_BEFORE_FAQ= - -# ───────────────────────────────────────────────────────────────────────────── -# Global announcement banner (e.g. planned downtime / maintenance notice). -# Set ENABLED to "true" to show the banner, and put the notice text in MESSAGE. -# ───────────────────────────────────────────────────────────────────────────── -NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED=false -NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE= - -# ───────────────────────────────────────────────────────────────────────────── -# Internal build-time fallbacks -# ───────────────────────────────────────────────────────────────────────────── -# -# Most deployments should leave these unset. -# -# These are only for SurfSense-managed production/cloud builds or packaged -# clients that do not have the normal server runtime config available. -# -# NEXT_PUBLIC_* values are embedded into the browser bundle during `next build`. -# Changing them after the bundle is built has no effect. - -# NEXT_PUBLIC_AUTH_TYPE=GOOGLE -# NEXT_PUBLIC_ETL_SERVICE=DOCLING -# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted -# NEXT_PUBLIC_APP_VERSION= \ No newline at end of file +NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_BEFORE_FAQ= \ No newline at end of file diff --git a/surfsense_web/Dockerfile b/surfsense_web/Dockerfile index 48cc28594..d851cf045 100644 --- a/surfsense_web/Dockerfile +++ b/surfsense_web/Dockerfile @@ -35,6 +35,21 @@ RUN apk add --no-cache git # Enable pnpm RUN corepack enable pnpm +# Build with placeholder values for NEXT_PUBLIC_* variables. +# These are replaced at container startup by docker-entrypoint.js +# with real values from the container's environment variables. +ARG NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__ +ARG NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__ +ARG NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__ +ARG NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__ +ARG NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__ + +ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=$NEXT_PUBLIC_FASTAPI_BACKEND_URL +ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=$NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE +ENV NEXT_PUBLIC_ETL_SERVICE=$NEXT_PUBLIC_ETL_SERVICE +ENV NEXT_PUBLIC_ZERO_CACHE_URL=$NEXT_PUBLIC_ZERO_CACHE_URL +ENV NEXT_PUBLIC_DEPLOYMENT_MODE=$NEXT_PUBLIC_DEPLOYMENT_MODE + COPY --from=deps /app/node_modules ./node_modules COPY . . @@ -63,6 +78,10 @@ COPY --from=builder /app/public ./public COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone/app/ ./ COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static +# Entrypoint scripts for runtime env var substitution +COPY --chown=nextjs:nodejs docker-entrypoint.js ./docker-entrypoint.js +COPY --chown=nextjs:nodejs --chmod=755 docker-entrypoint.sh ./docker-entrypoint.sh + USER nextjs EXPOSE 3000 @@ -72,4 +91,4 @@ ENV PORT=3000 # server.js is created by next build from the standalone output # https://nextjs.org/docs/pages/api-reference/config/next-config-js/output ENV HOSTNAME="0.0.0.0" -CMD ["node", "server.js"] \ No newline at end of file +ENTRYPOINT ["/bin/sh", "./docker-entrypoint.sh"] \ No newline at end of file diff --git a/surfsense_web/app/(home)/changelog/page.tsx b/surfsense_web/app/(home)/changelog/page.tsx index b7aa14d20..42bac512a 100644 --- a/surfsense_web/app/(home)/changelog/page.tsx +++ b/surfsense_web/app/(home)/changelog/page.tsx @@ -3,7 +3,10 @@ import type { MDXComponents } from "mdx/types"; import type { Metadata } from "next"; import type { ComponentType } from "react"; import { changelog } from "@/.source/server"; -import { ChangelogTimeline, type ChangelogTimelineEntry } from "@/components/ui/changelog-timeline"; +import { + ChangelogTimeline, + type ChangelogTimelineEntry, +} from "@/components/ui/changelog-timeline"; import { formatDate } from "@/lib/utils"; import { getMDXComponents } from "@/mdx-components"; diff --git a/surfsense_web/app/(home)/free/[model_slug]/page.tsx b/surfsense_web/app/(home)/free/[model_slug]/page.tsx index 71fc925e4..e72c3d6e3 100644 --- a/surfsense_web/app/(home)/free/[model_slug]/page.tsx +++ b/surfsense_web/app/(home)/free/[model_slug]/page.tsx @@ -7,7 +7,7 @@ import { FAQJsonLd, JsonLd } from "@/components/seo/json-ld"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; import type { AnonModel } from "@/contracts/types/anonymous-chat.types"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; interface PageProps { params: Promise<{ model_slug: string }>; @@ -16,7 +16,7 @@ interface PageProps { async function getModel(slug: string): Promise<AnonModel | null> { try { const res = await fetch( - `${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models/${encodeURIComponent(slug)}`, + `${BACKEND_URL}/api/v1/public/anon-chat/models/${encodeURIComponent(slug)}`, { next: { revalidate: 300 } } ); if (!res.ok) return null; @@ -28,7 +28,7 @@ async function getModel(slug: string): Promise<AnonModel | null> { async function getAllModels(): Promise<AnonModel[]> { try { - const res = await fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 300 }, }); if (!res.ok) return []; @@ -136,7 +136,7 @@ export async function generateMetadata({ params }: PageProps): Promise<Metadata> export async function generateStaticParams() { const models = await getAllModels(); - return models.flatMap((m) => (m.seo_slug ? [{ model_slug: m.seo_slug }] : [])); + return models.filter((m) => m.seo_slug).map((m) => ({ model_slug: m.seo_slug! })); } export default async function FreeModelPage({ params }: PageProps) { diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index b754502f6..0092ca2d5 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -16,7 +16,7 @@ import { TableRow, } from "@/components/ui/table"; import type { AnonModel } from "@/contracts/types/anonymous-chat.types"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; export const metadata: Metadata = { title: "Free AI Chat, No Login Required | SurfSense", @@ -94,7 +94,7 @@ export const metadata: Metadata = { async function getModels(): Promise<AnonModel[]> { try { - const res = await fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 300 }, }); if (!res.ok) return []; @@ -246,6 +246,11 @@ export default async function FreeHubPage() { className="group flex flex-col gap-0.5" > <span className="font-medium group-hover:underline">{model.name}</span> + {model.description && ( + <span className="text-xs text-muted-foreground line-clamp-1"> + {model.description} + </span> + )} </Link> </TableCell> <TableCell> diff --git a/surfsense_web/app/(home)/layout.tsx b/surfsense_web/app/(home)/layout.tsx index c749b10f3..57dd9919e 100644 --- a/surfsense_web/app/(home)/layout.tsx +++ b/surfsense_web/app/(home)/layout.tsx @@ -2,7 +2,6 @@ import { usePathname } from "next/navigation"; import { FooterNew } from "@/components/homepage/footer-new"; -import { GlobalAnnouncement } from "@/components/homepage/global-announcement"; import { Navbar } from "@/components/homepage/navbar"; export default function HomePageLayout({ children }: { children: React.ReactNode }) { @@ -16,7 +15,6 @@ export default function HomePageLayout({ children }: { children: React.ReactNode return ( <main className="min-h-screen bg-linear-to-b from-gray-50 to-gray-100 text-gray-900 dark:from-black dark:to-gray-900 dark:text-white overflow-x-hidden"> - <GlobalAnnouncement /> <Navbar /> {children} {!isAuthPage && <FooterNew />} diff --git a/surfsense_web/app/(home)/login/GoogleLoginButton.tsx b/surfsense_web/app/(home)/login/GoogleLoginButton.tsx index a9e1b553e..1c91f8115 100644 --- a/surfsense_web/app/(home)/login/GoogleLoginButton.tsx +++ b/surfsense_web/app/(home)/login/GoogleLoginButton.tsx @@ -3,7 +3,7 @@ import { useTranslations } from "next-intl"; import { useState } from "react"; import { Logo } from "@/components/Logo"; import { Button } from "@/components/ui/button"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { AmbientBackground } from "./AmbientBackground"; @@ -51,7 +51,7 @@ export function GoogleLoginButton() { // cross-origin fetch requests may not be sent on subsequent redirects. // The authorize-redirect endpoint does a server-side redirect to Google // and sets the CSRF cookie properly for same-site context. - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; return ( <div className="relative w-full overflow-hidden"> diff --git a/surfsense_web/app/(home)/login/LocalLoginForm.tsx b/surfsense_web/app/(home)/login/LocalLoginForm.tsx index 108151512..9692d35e1 100644 --- a/surfsense_web/app/(home)/login/LocalLoginForm.tsx +++ b/surfsense_web/app/(home)/login/LocalLoginForm.tsx @@ -7,10 +7,10 @@ import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useState } from "react"; import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors"; +import { AUTH_TYPE } from "@/lib/env-config"; import { ValidationError } from "@/lib/error"; import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events"; @@ -26,7 +26,7 @@ export function LocalLoginForm() { title: null, message: null, }); - const { authType } = useRuntimeConfig(); + const authType = AUTH_TYPE; const router = useRouter(); const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom); diff --git a/surfsense_web/app/(home)/login/layout.tsx b/surfsense_web/app/(home)/login/layout.tsx deleted file mode 100644 index e14aec239..000000000 --- a/surfsense_web/app/(home)/login/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function LoginLayout({ children }: { children: React.ReactNode }) { - return <RuntimeConfig>{children}</RuntimeConfig>; -} diff --git a/surfsense_web/app/(home)/login/page.tsx b/surfsense_web/app/(home)/login/page.tsx index 8f146f815..42a9182e9 100644 --- a/surfsense_web/app/(home)/login/page.tsx +++ b/surfsense_web/app/(home)/login/page.tsx @@ -6,10 +6,11 @@ import { useTranslations } from "next-intl"; import { Suspense, useEffect, useState } from "react"; import { toast } from "sonner"; import { Logo } from "@/components/Logo"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; +import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getAuthErrorDetails, shouldRetry } from "@/lib/auth-errors"; import { setRedirectPath } from "@/lib/auth-utils"; +import { AUTH_TYPE } from "@/lib/env-config"; import { AmbientBackground } from "./AmbientBackground"; import { GoogleLoginButton } from "./GoogleLoginButton"; import { LocalLoginForm } from "./LocalLoginForm"; @@ -18,7 +19,8 @@ function LoginContent() { const t = useTranslations("auth"); const tCommon = useTranslations("common"); const router = useRouter(); - const { authType } = useRuntimeConfig(); + const [authType, setAuthType] = useState<string | null>(null); + const [isLoading, setIsLoading] = useState(true); const [urlError, setUrlError] = useState<{ title: string; message: string } | null>(null); const searchParams = useSearchParams(); @@ -94,7 +96,19 @@ function LoginContent() { duration: 4000, }); } - }, [router, searchParams, t, tCommon]); + + // Get the auth type from centralized config + setAuthType(AUTH_TYPE); + setIsLoading(false); + }, [searchParams, t, tCommon]); + + // Use global loading screen for auth type determination - spinner animation won't reset + useGlobalLoadingEffect(isLoading); + + // Show nothing while loading - the GlobalLoadingProvider handles the loading UI + if (isLoading) { + return null; + } if (authType === "GOOGLE") { return <GoogleLoginButton />; diff --git a/surfsense_web/app/(home)/register/layout.tsx b/surfsense_web/app/(home)/register/layout.tsx deleted file mode 100644 index 361df85c1..000000000 --- a/surfsense_web/app/(home)/register/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function RegisterLayout({ children }: { children: React.ReactNode }) { - return <RuntimeConfig>{children}</RuntimeConfig>; -} diff --git a/surfsense_web/app/(home)/register/page.tsx b/surfsense_web/app/(home)/register/page.tsx index 9421a0156..1fd1a4ecb 100644 --- a/surfsense_web/app/(home)/register/page.tsx +++ b/surfsense_web/app/(home)/register/page.tsx @@ -9,11 +9,11 @@ import { useEffect, useState } from "react"; import { type ExternalToast, toast } from "sonner"; import { registerMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; import { Logo } from "@/components/Logo"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors"; import { getBearerToken } from "@/lib/auth-utils"; +import { AUTH_TYPE } from "@/lib/env-config"; import { AppError, ValidationError } from "@/lib/error"; import { trackRegistrationAttempt, @@ -25,7 +25,6 @@ import { AmbientBackground } from "../login/AmbientBackground"; export default function RegisterPage() { const t = useTranslations("auth"); const tCommon = useTranslations("common"); - const { authType } = useRuntimeConfig(); const [email, setEmail] = useState(""); const [password, setPassword] = useState(""); const [confirmPassword, setConfirmPassword] = useState(""); @@ -45,10 +44,10 @@ export default function RegisterPage() { router.replace("/dashboard"); return; } - if (authType !== "LOCAL") { + if (AUTH_TYPE !== "LOCAL") { router.push("/login"); } - }, [authType, router]); + }, [router]); const handleSubmit = (e: React.FormEvent) => { e.preventDefault(); diff --git a/surfsense_web/app/api/v1/[...path]/route.ts b/surfsense_web/app/api/v1/[...path]/route.ts index 66ea78af5..418bf1a33 100644 --- a/surfsense_web/app/api/v1/[...path]/route.ts +++ b/surfsense_web/app/api/v1/[...path]/route.ts @@ -14,11 +14,7 @@ const HOP_BY_HOP_HEADERS = new Set([ ]); function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; + const base = process.env.FASTAPI_BACKEND_INTERNAL_URL || "http://localhost:8000"; return base.endsWith("/") ? base.slice(0, -1) : base; } diff --git a/surfsense_web/app/api/zero/query/route.ts b/surfsense_web/app/api/zero/query/route.ts index f08b012e7..35ef51fb5 100644 --- a/surfsense_web/app/api/zero/query/route.ts +++ b/surfsense_web/app/api/zero/query/route.ts @@ -1,7 +1,7 @@ import { mustGetQuery } from "@rocicorp/zero"; import { handleQueryRequest } from "@rocicorp/zero/server"; import { NextResponse } from "next/server"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import type { Context } from "@/types/zero"; import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; @@ -11,7 +11,11 @@ import { schema } from "@/zero/schema"; // (e.g. http://backend:8000). The browser-facing NEXT_PUBLIC_FASTAPI_BACKEND_URL // (e.g. http://localhost:8929) does NOT resolve from inside the frontend // container and would make every authenticated Zero query fail with a 503. -const backendURL = SERVER_BACKEND_URL.replace(/\/$/, ""); +const backendURL = ( + process.env.FASTAPI_BACKEND_INTERNAL_URL || + process.env.BACKEND_URL || + "http://localhost:8000" +).replace(/\/$/, ""); async function authenticateRequest( request: Request diff --git a/surfsense_web/app/auth/[...path]/route.ts b/surfsense_web/app/auth/[...path]/route.ts deleted file mode 100644 index 923f6eef3..000000000 --- a/surfsense_web/app/auth/[...path]/route.ts +++ /dev/null @@ -1,74 +0,0 @@ -import type { NextRequest } from "next/server"; - -export const dynamic = "force-dynamic"; - -const HOP_BY_HOP_HEADERS = new Set([ - "connection", - "keep-alive", - "proxy-authenticate", - "proxy-authorization", - "te", - "trailer", - "transfer-encoding", - "upgrade", -]); - -function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; - return base.endsWith("/") ? base.slice(0, -1) : base; -} - -function toUpstreamHeaders(headers: Headers) { - const nextHeaders = new Headers(headers); - nextHeaders.delete("host"); - nextHeaders.delete("content-length"); - return nextHeaders; -} - -function toClientHeaders(headers: Headers) { - const nextHeaders = new Headers(headers); - for (const header of HOP_BY_HOP_HEADERS) { - nextHeaders.delete(header); - } - return nextHeaders; -} - -async function proxy(request: NextRequest, context: { params: Promise<{ path?: string[] }> }) { - const params = await context.params; - const path = params.path?.join("/") || ""; - const upstreamUrl = new URL(`${getBackendBaseUrl()}/auth/${path}`); - upstreamUrl.search = request.nextUrl.search; - - const hasBody = request.method !== "GET" && request.method !== "HEAD"; - - const response = await fetch(upstreamUrl, { - method: request.method, - headers: toUpstreamHeaders(request.headers), - body: hasBody ? request.body : undefined, - // `duplex: "half"` is required by the Fetch spec when streaming a - // ReadableStream as the request body. Avoids buffering uploads in heap. - // @ts-expect-error - `duplex` is not yet in lib.dom RequestInit types. - duplex: hasBody ? "half" : undefined, - redirect: "manual", - }); - - return new Response(response.body, { - status: response.status, - statusText: response.statusText, - headers: toClientHeaders(response.headers), - }); -} - -export { - proxy as GET, - proxy as POST, - proxy as PUT, - proxy as PATCH, - proxy as DELETE, - proxy as OPTIONS, - proxy as HEAD, -}; diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index 1ee71c636..b2e7b2532 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { AlarmClock } from "lucide-react"; +import { Workflow } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return ( <div className="rounded-lg border border-dashed border-border/60 bg-muted/20 px-6 py-12 text-center"> <div className="mx-auto flex h-12 w-12 items-center justify-center rounded-full bg-muted text-muted-foreground"> - <AlarmClock className="h-6 w-6" aria-hidden /> + <Workflow className="h-6 w-6" aria-hidden /> </div> <h3 className="mt-4 text-base font-semibold text-foreground">No automations yet</h3> <p className="mt-1 text-sm text-muted-foreground max-w-md mx-auto"> diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 74c604173..8314a5179 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { AlarmClock, CalendarDays, Info } from "lucide-react"; +import { CalendarDays, Info, Workflow } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ <TableRow className="hover:bg-transparent border-b border-border/60"> <TableHead className="px-4 md:px-6 border-r border-border/60"> <span className="flex items-center gap-1.5 text-sm font-medium text-muted-foreground/70"> - <AlarmClock size={14} className="opacity-60 text-muted-foreground" /> + <Workflow size={14} className="opacity-60 text-muted-foreground" /> Name </span> </TableHead> diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx index a68e53a1c..59967080f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx @@ -130,7 +130,7 @@ export function AutomationBuilderForm({ // data into state, so there's no flicker/loop and the user's pick is sticky. const resolvedModels = useMemo<BuilderModels>( () => ({ - chatModelId: form.models.chatModelId || eligibleModels.llm.defaultId || 0, + agentLlmId: form.models.agentLlmId || eligibleModels.llm.defaultId || 0, imageConfigId: form.models.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: form.models.visionConfigId || eligibleModels.vision.defaultId || 0, }), diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx index 6dd42366b..2c4a0bf60 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx @@ -25,7 +25,7 @@ import { getProviderIcon } from "@/lib/provider-icons"; import { Field } from "./form-field"; export interface AutomationModelSelection { - chatModelId: number; + agentLlmId: number; imageConfigId: number; visionConfigId: number; } @@ -39,7 +39,7 @@ interface AutomationModelFieldsProps { } /** - * Three eligible-only model pickers (Chat / Image / Vision) for the + * Three eligible-only model pickers (Agent LLM / Image / Vision) for the * automation builder + chat approval card. Options come from * {@link useAutomationEligibleModels} (premium globals + BYOK only); selection * is validated + snapshotted onto `definition.models` at create time. @@ -51,18 +51,18 @@ export function AutomationModelFields({ errors, }: AutomationModelFieldsProps) { const { llm, image, vision, isLoading } = useAutomationEligibleModels(); - const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/models`; + const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/roles`; return ( <div className="flex flex-col gap-4"> <ModelSelectField - label="Chat model" + label="Agent model" kind={llm} - value={value.chatModelId} + value={value.agentLlmId} isLoading={isLoading} rolesHref={rolesHref} - error={errors?.chatModelId} - onChange={(id) => onChange({ chatModelId: id })} + error={errors?.agentLlmId} + onChange={(id) => onChange({ agentLlmId: id })} /> <ModelSelectField label="Image model" diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 8ea4c1d7d..b4ec015b7 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -1,15 +1,48 @@ "use client"; -import { AutoReloadSettings } from "@/components/settings/auto-reload-settings"; -import { BuyCreditsContent } from "@/components/settings/buy-credits-content"; +import { useState } from "react"; +import { BuyPagesContent } from "@/components/settings/buy-pages-content"; +import { BuyTokensContent } from "@/components/settings/buy-tokens-content"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; + +const TABS = [ + { id: "pages", label: "Pages" }, + { id: "tokens", label: "Premium Credit" }, +] as const; + +type TabId = (typeof TABS)[number]["id"]; export default function BuyMorePage() { + const [activeTab, setActiveTab] = useState<TabId>("pages"); + return ( - <div className="flex min-h-[37rem] w-full select-none items-center justify-center py-8"> - <div className="w-full max-w-md space-y-8"> - <BuyCreditsContent /> - <AutoReloadSettings /> - </div> + <div className="w-full select-none"> + <Tabs + value={activeTab} + onValueChange={(value) => { + setActiveTab(value as TabId); + }} + className="relative min-h-[37rem] w-full" + > + <TabsList className="absolute top-20 left-1/2 -translate-x-1/2 rounded-xl bg-accent p-1"> + {TABS.map((tab) => ( + <TabsTrigger + key={tab.id} + value={tab.id} + className="h-8 rounded-lg px-4 text-sm font-semibold text-accent-foreground transition-colors hover:bg-transparent hover:text-white data-[state=active]:bg-[#4a4a4a] data-[state=active]:text-white data-[state=active]:shadow-none" + > + {tab.label} + </TabsTrigger> + ))} + </TabsList> + + <TabsContent value="pages" className="mt-0 flex min-h-[37rem] items-center pt-14"> + <BuyPagesContent /> + </TabsContent> + <TabsContent value="tokens" className="mt-0 flex min-h-[37rem] items-center pt-14"> + <BuyTokensContent /> + </TabsContent> + </Tabs> </div> ); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index c3eeb2bf6..3a41b5998 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -4,15 +4,15 @@ import { useAtomValue, useSetAtom } from "jotai"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import type React from "react"; -import { useEffect, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; import { - globalLlmConfigStatusAtom, - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + globalNewLLMConfigsAtom, + llmPreferencesAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { DocumentUploadDialogProvider } from "@/components/assistant-ui/document-upload-popup"; import { LayoutDataProvider } from "@/components/layout"; @@ -34,30 +34,30 @@ export function DashboardClientLayout({ const router = useRouter(); const pathname = usePathname(); const { search_space_id } = useParams(); - const activeSearchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const setActiveSearchSpaceIdState = useSetAtom(activeSearchSpaceIdAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); - const { data: modelRoles = {}, isLoading: loading, error } = useAtomValue(modelRolesAtom); - const { data: globalConnections = [], isLoading: globalConfigsLoading } = useAtomValue( - globalModelConnectionsAtom - ); - const { data: modelConnections = [], isLoading: modelConnectionsLoading } = - useAtomValue(modelConnectionsAtom); - const { data: globalConfigStatus, isLoading: globalConfigStatusLoading } = - useAtomValue(globalLlmConfigStatusAtom); + const { + data: preferences = {}, + isFetching: loading, + error, + refetch: refetchPreferences, + } = useAtomValue(llmPreferencesAtom); + const { data: globalConfigs = [], isFetching: globalConfigsLoading } = + useAtomValue(globalNewLLMConfigsAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const isOnboardingComplete = useCallback(() => { + return isLlmOnboardingComplete(preferences.agent_llm_id, globalConfigs.length > 0); + }, [preferences.agent_llm_id, globalConfigs.length]); const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom); const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); + const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); + const hasAttemptedAutoConfig = useRef(false); const isOnboardingPage = pathname?.includes("/onboard"); const isOwner = access?.is_owner ?? false; - const isSearchSpaceReady = activeSearchSpaceId === searchSpaceId; - - useEffect(() => { - if (isSearchSpaceReady) return; - setHasCheckedOnboarding(false); - }, [isSearchSpaceReady]); useEffect(() => { if (isOnboardingPage) { @@ -66,27 +66,13 @@ export function DashboardClientLayout({ } if ( - isSearchSpaceReady && !loading && !accessLoading && !globalConfigsLoading && - !globalConfigStatusLoading && - !modelConnectionsLoading && - !hasCheckedOnboarding + !hasCheckedOnboarding && + !isAutoConfiguring ) { - // Onboarding is only relevant when no operator-provided - // global_llm_config.yaml exists. When it does, search spaces inherit - // the global config and should never be forced into onboarding. - if (globalConfigStatus?.exists) { - setHasCheckedOnboarding(true); - return; - } - - const onboardingComplete = isLlmOnboardingComplete( - modelRoles.chat_model_id, - globalConnections, - modelConnections - ); + const onboardingComplete = isOnboardingComplete(); if (onboardingComplete) { setHasCheckedOnboarding(true); @@ -98,25 +84,56 @@ export function DashboardClientLayout({ return; } + if (globalConfigs.length > 0 && !hasAttemptedAutoConfig.current) { + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + + const autoConfigureWithGlobal = async () => { + try { + const firstGlobalConfig = globalConfigs[0]; + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { + agent_llm_id: firstGlobalConfig.id, + }, + }); + + await refetchPreferences(); + + toast.success("AI configured automatically!", { + description: `Using ${firstGlobalConfig.name}. Customize in Settings.`, + }); + + setHasCheckedOnboarding(true); + } catch (error) { + console.error("Auto-configuration failed:", error); + router.push(`/dashboard/${searchSpaceId}/onboard`); + } finally { + setIsAutoConfiguring(false); + } + }; + + autoConfigureWithGlobal(); + return; + } + router.push(`/dashboard/${searchSpaceId}/onboard`); setHasCheckedOnboarding(true); } }, [ - isSearchSpaceReady, loading, accessLoading, globalConfigsLoading, - globalConfigStatusLoading, - globalConfigStatus, - modelConnectionsLoading, - modelRoles.chat_model_id, - globalConnections, - modelConnections, + isOnboardingComplete, isOnboardingPage, isOwner, + isAutoConfiguring, + globalConfigs, router, searchSpaceId, hasCheckedOnboarding, + updatePreferences, + refetchPreferences, ]); const electronAPI = useElectronAPI(); @@ -168,14 +185,10 @@ export function DashboardClientLayout({ // Determine if we should show loading const shouldShowLoading = - !hasCheckedOnboarding && - (!isSearchSpaceReady || - loading || - accessLoading || - globalConfigsLoading || - globalConfigStatusLoading || - modelConnectionsLoading) && - !isOnboardingPage; + (!hasCheckedOnboarding && + (loading || accessLoading || globalConfigsLoading) && + !isOnboardingPage) || + isAutoConfiguring; // Use global loading screen - spinner animation won't reset useGlobalLoadingEffect(shouldShowLoading); diff --git a/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts b/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts index 96f7dc349..304f33a33 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts +++ b/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts @@ -16,12 +16,9 @@ export async function GET( }; const result = JSON.stringify(payload); - const response = new NextResponse(null, { - status: 302, - headers: { - Location: `/dashboard/${search_space_id}/new-chat`, - }, - }); + const redirectUrl = new URL(`/dashboard/${search_space_id}/new-chat`, request.url); + + const response = NextResponse.redirect(redirectUrl, { status: 302 }); response.cookies.set(OAUTH_RESULT_COOKIE, result, { path: "/", maxAge: 60, diff --git a/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx deleted file mode 100644 index 3ff4c3cf8..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx +++ /dev/null @@ -1,11 +0,0 @@ -"use client"; - -import { EarnCreditsContent } from "@/components/settings/earn-credits-content"; - -export default function EarnCreditsPage() { - return ( - <div className="w-full select-none space-y-6"> - <EarnCreditsContent /> - </div> - ); -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx index 46f1965d0..4b3301b9f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx @@ -1,18 +1,11 @@ "use client"; -import { useParams, useRouter } from "next/navigation"; -import { useEffect } from "react"; +import { MorePagesContent } from "@/components/settings/more-pages-content"; -// Legacy route kept as a redirect: older "insufficient credits" notifications -// and bookmarks may still point at /more-pages. export default function MorePagesPage() { - const router = useRouter(); - const params = useParams(); - const searchSpaceId = params?.search_space_id ?? ""; - - useEffect(() => { - router.replace(`/dashboard/${searchSpaceId}/earn-credits`); - }, [router, searchSpaceId]); - - return null; + return ( + <div className="w-full select-none space-y-6"> + <MorePagesContent /> + </div> + ); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 3594e15eb..75cfa4184 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -77,6 +77,11 @@ import { convertToThreadMessage, reconcileInterruptedAssistantMessages, } from "@/lib/chat/message-utils"; +import { + isPodcastGenerating, + looksLikePodcastRequest, + setActivePodcastTaskId, +} from "@/lib/chat/podcast-state"; import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; import { consumeSseEvents, processSharedStreamEvent } from "@/lib/chat/stream-pipeline"; import { @@ -106,7 +111,7 @@ import { extractUserTurnForNewChatApi, type NewChatUserImagePayload, } from "@/lib/chat/user-turn-api-parts"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { NotFoundError } from "@/lib/error"; import { trackChatBlocked, @@ -613,18 +618,6 @@ export default function NewChatPage() { return; } - if (normalized.channel === "inline") { - if (normalized.assistantMessage) { - await persistAssistantErrorMessage({ - threadId, - assistantMsgId, - text: normalized.assistantMessage, - }); - } - toast.error(normalized.userMessage); - return; - } - toast.error(normalized.userMessage); }, [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] @@ -765,9 +758,6 @@ export default function NewChatPage() { const loadedMessages = reconcileInterruptedAssistantMessages(messagesResponse.messages).map( convertToThreadMessage ); - if (messages.length > 0 && loadedMessages.length < messages.length) { - return; - } setMessages(loadedMessages); tokenUsageStore.clear(); @@ -788,7 +778,6 @@ export default function NewChatPage() { }, [ activeThreadId, isRunning, - messages.length, setMessageDocumentsMap, threadMessagesQuery.data, tokenUsageStore, @@ -919,9 +908,10 @@ export default function NewChatPage() { if (threadId) { const token = getBearerToken(); if (token) { + const backendUrl = BACKEND_URL; try { const response = await fetch( - buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`), + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, { method: "POST", headers: { @@ -964,6 +954,11 @@ export default function NewChatPage() { if (!userQuery.trim() && userImages.length === 0) return; + if (userQuery.trim() && isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { + toast.warning("A podcast is already being generated."); + return; + } + const token = getBearerToken(); if (!token) { toast.error("Not authenticated. Please log in again."); @@ -1109,6 +1104,7 @@ export default function NewChatPage() { let streamBatcher: FrameBatchedUpdater | null = null; try { + const backendUrl = BACKEND_URL; const selection = await getAgentFilesystemSelection(searchSpaceId, { localFilesystemEnabled, }); @@ -1145,7 +1141,7 @@ export default function NewChatPage() { } const response = await fetchWithTurnCancellingRetry(() => - fetch(buildBackendUrl("/api/v1/new_chat"), { + fetch(`${backendUrl}/api/v1/new_chat`, { method: "POST", headers: { "Content-Type": "application/json", @@ -1222,6 +1218,17 @@ export default function NewChatPage() { recentCancelRequestedAtRef.current = Date.now(); } }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } + } + } + }, }) ) { return; @@ -1640,11 +1647,12 @@ export default function NewChatPage() { } try { + const backendUrl = BACKEND_URL; const selection = await getAgentFilesystemSelection(searchSpaceId, { localFilesystemEnabled, }); const response = await fetchWithTurnCancellingRetry(() => - fetch(buildBackendUrl(`/api/v1/threads/${resumeThreadId}/resume`), { + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", headers: { "Content-Type": "application/json", @@ -2179,6 +2187,17 @@ export default function NewChatPage() { recentCancelRequestedAtRef.current = Date.now(); } }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } + } + } + }, }) ) { return; @@ -2550,7 +2569,7 @@ export default function NewChatPage() { > <div key={searchSpaceId} className="flex h-full overflow-hidden"> <div className="relative flex-1 flex flex-col min-w-0 overflow-hidden"> - <Thread hasActiveThread={!!activeThreadId} /> + <Thread /> {isThreadMessagesLoading ? ( <div className="absolute inset-0 z-10 bg-panel"> <ThreadMessagesSkeleton /> diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 8efe81cce..de5c961e8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -2,89 +2,193 @@ import { useAtomValue } from "jotai"; import { useParams, useRouter } from "next/navigation"; -import { useEffect, useMemo } from "react"; +import { useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; import { - globalLlmConfigStatusAtom, - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + createNewLLMConfigMutationAtom, + updateLLMPreferencesMutationAtom, +} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { Logo } from "@/components/Logo"; -import { ModelProviderConnectionsPanel } from "@/components/settings/model-connections/model-provider-connections-panel"; +import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { hasEnabledChatModel, isLlmOnboardingComplete } from "@/lib/onboarding"; +import { isLlmOnboardingComplete } from "@/lib/onboarding"; export default function OnboardPage() { const router = useRouter(); const params = useParams(); const searchSpaceId = Number(params.search_space_id); - const { data: globalConnections = [], isLoading: globalLoading } = useAtomValue( - globalModelConnectionsAtom - ); - const { data: connections = [] } = useAtomValue(modelConnectionsAtom); - const { data: roles = {}, isLoading: rolesLoading } = useAtomValue(modelRolesAtom); - const { data: globalConfigStatus, isLoading: globalConfigStatusLoading } = - useAtomValue(globalLlmConfigStatusAtom); + // Queries + const { + data: globalConfigs = [], + isFetching: globalConfigsLoading, + isSuccess: globalConfigsLoaded, + } = useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences = {}, isFetching: preferencesLoading } = + useAtomValue(llmPreferencesAtom); + // Mutations + const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue( + createNewLLMConfigMutationAtom + ); + const { mutateAsync: updatePreferences, isPending: isUpdatingPreferences } = useAtomValue( + updateLLMPreferencesMutationAtom + ); + + // State + const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); + const hasAttemptedAutoConfig = useRef(false); + + // Check authentication useEffect(() => { - if (!getBearerToken()) redirectToLogin(); + const token = getBearerToken(); + if (!token) { + redirectToLogin(); + } }, []); - const hasUsableChatModel = useMemo( - () => hasEnabledChatModel([...globalConnections, ...connections]), - [globalConnections, connections] + const isOnboardingComplete = isLlmOnboardingComplete( + preferences.agent_llm_id, + globalConfigs.length > 0 ); - const onboardingComplete = isLlmOnboardingComplete( - roles.chat_model_id, - globalConnections, - connections - ); - - const isLoading = globalLoading || rolesLoading || globalConfigStatusLoading; - - // Onboarding only applies when no global_llm_config.yaml exists. If a global - // config is present (or onboarding is already complete), leave this page. - const shouldLeaveOnboarding = - !isLoading && (Boolean(globalConfigStatus?.exists) || onboardingComplete); - useEffect(() => { - if (shouldLeaveOnboarding) { - router.replace(`/dashboard/${searchSpaceId}/new-chat`); + if (!preferencesLoading && globalConfigsLoaded && isOnboardingComplete) { + router.push(`/dashboard/${searchSpaceId}/new-chat`); } - }, [shouldLeaveOnboarding, router, searchSpaceId]); + }, [preferencesLoading, globalConfigsLoaded, isOnboardingComplete, router, searchSpaceId]); - useGlobalLoadingEffect(isLoading || shouldLeaveOnboarding); + useEffect(() => { + const autoConfigureWithGlobal = async () => { + if (hasAttemptedAutoConfig.current) return; + if (globalConfigsLoading || preferencesLoading) return; + if (!globalConfigsLoaded) return; + if (isOnboardingComplete) return; - if (isLoading || shouldLeaveOnboarding) return null; + if (globalConfigs.length > 0) { + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + + try { + const firstGlobalConfig = globalConfigs[0]; + + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: firstGlobalConfig.id, + }, + }); + + toast.success("AI configured automatically!", { + description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`, + }); + + router.push(`/dashboard/${searchSpaceId}/new-chat`); + } catch (error) { + console.error("Auto-configuration failed:", error); + toast.error("Auto-configuration failed. Please add a configuration manually."); + setIsAutoConfiguring(false); + } + } + }; + + autoConfigureWithGlobal(); + }, [ + globalConfigs, + globalConfigsLoading, + globalConfigsLoaded, + preferencesLoading, + isOnboardingComplete, + updatePreferences, + searchSpaceId, + router, + ]); + + const handleSubmit = async (formData: LLMConfigFormData) => { + try { + const newConfig = await createConfig(formData); + + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: newConfig.id, + }, + }); + + toast.success("Configuration created!", { + description: "Redirecting to chat...", + }); + + router.push(`/dashboard/${searchSpaceId}/new-chat`); + } catch (error) { + console.error("Failed to create config:", error); + if (error instanceof Error) { + toast.error(error.message || "Failed to create configuration"); + } + } + }; + + const isSubmitting = isCreating || isUpdatingPreferences; + + const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring; + useGlobalLoadingEffect(isLoading); + + if (isLoading) { + return null; + } + + if (globalConfigs.length > 0 && !isAutoConfiguring) { + return null; + } return ( - <div className="flex min-h-screen select-none flex-col items-center justify-center bg-main-panel p-4"> - <div className="w-full max-w-3xl space-y-6 text-center"> - <Logo className="mx-auto h-12 w-12" /> - <div className="space-y-2"> - <h1 className="text-2xl font-semibold tracking-tight">Choose a model</h1> - <p className="text-sm text-muted-foreground"> - Connect any supported provider, then enable the models you want SurfSense to use. - </p> + <div className="h-screen flex flex-col items-center p-4 bg-main-panel select-none overflow-hidden"> + <div className="w-full max-w-lg flex flex-col min-h-0 h-full gap-6 py-8"> + {/* Header */} + <div className="text-center space-y-3 shrink-0"> + <Logo className="w-12 h-12 mx-auto" /> + <div className="space-y-1"> + <h1 className="text-2xl font-semibold tracking-tight">Configure Your AI</h1> + <p className="text-sm text-muted-foreground"> + Add your LLM provider to get started with SurfSense + </p> + </div> + </div> + + {/* Form card */} + <div className="rounded-xl border bg-main-panel flex-1 min-h-0 overflow-y-auto px-6 py-6"> + <LLMConfigForm + searchSpaceId={searchSpaceId} + onSubmit={handleSubmit} + mode="create" + showAdvanced={true} + formId="onboard-config-form" + initialData={{ + citations_enabled: true, + use_default_system_instructions: true, + }} + /> + </div> + + {/* Footer */} + <div className="text-center space-y-4 shrink-0"> + <Button + type="submit" + form="onboard-config-form" + disabled={isSubmitting} + className="relative text-sm h-9 min-w-[180px]" + > + <span className={isSubmitting ? "opacity-0" : ""}>Start Using SurfSense</span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + <p className="text-xs text-muted-foreground">You can add more configurations later</p> </div> - <ModelProviderConnectionsPanel - searchSpaceId={searchSpaceId} - connections={connections} - className="flex flex-col gap-6 text-left" - footerAction={ - <Button - className="min-w-[112px]" - disabled={!onboardingComplete || !hasUsableChatModel} - onClick={() => router.push(`/dashboard/${searchSpaceId}/new-chat`)} - > - Start - </Button> - } - showAddProviderHeader={false} - /> </div> </div> ); diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index a8a88c5a5..8eaec3e5a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -112,7 +112,9 @@ export default function PurchaseSuccessPage() { {state.kind === "still_pending" && "Your payment is still being processed by your bank. We'll apply your purchase as soon as it clears — usually within a few minutes. You can safely close this page."} {state.kind === "completed" && - `Added ${formatCredit(state.data.credit_micros_granted ?? 0)} of credit to your account.`} + (state.data.purchase_type === "page_packs" + ? `Added ${formatNumber(state.data.pages_granted ?? 0)} pages to your account.` + : `Added ${formatCredit(state.data.premium_credit_micros_granted ?? 0)} of premium credit to your account.`)} {state.kind === "failed" && "Stripe reported the checkout as failed or expired. Your card was not charged."} {state.kind === "error" && @@ -121,9 +123,18 @@ export default function PurchaseSuccessPage() { </CardDescription> </CardHeader> <CardContent className="space-y-3 text-center"> - {state.kind === "completed" && ( + {state.kind === "completed" && state.data.purchase_type === "page_packs" && ( <p className="text-sm text-muted-foreground"> - New credit balance: {formatCredit(state.data.credit_micros_balance ?? 0)} + New balance: {formatNumber(state.data.pages_limit ?? 0)} total pages + {typeof state.data.pages_used === "number" + ? ` (${formatNumber((state.data.pages_limit ?? 0) - state.data.pages_used)} remaining)` + : ""} + </p> + )} + {state.kind === "completed" && state.data.purchase_type === "premium_tokens" && ( + <p className="text-sm text-muted-foreground"> + New premium credit balance:{" "} + {formatCredit(state.data.premium_credit_micros_limit ?? 0)} </p> )} {state.kind === "error" && ( @@ -135,7 +146,7 @@ export default function PurchaseSuccessPage() { <Link href={`/dashboard/${searchSpaceId}/new-chat`}>Back to Dashboard</Link> </Button> <Button asChild variant="outline" className="w-full"> - <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy credits</Link> + <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy More</Link> </Button> </CardFooter> </Card> @@ -143,6 +154,10 @@ export default function PurchaseSuccessPage() { ); } +function formatNumber(n: number): string { + return new Intl.NumberFormat("en-US").format(n); +} + function formatCredit(micros: number): string { const dollars = micros / 1_000_000; return new Intl.NumberFormat("en-US", { diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx new file mode 100644 index 000000000..b300f8078 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx @@ -0,0 +1,6 @@ +import { ImageModelManager } from "@/components/settings/image-model-manager"; + +export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { + const { search_space_id } = await params; + return <ImageModelManager searchSpaceId={Number(search_space_id)} />; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index bb928f8f7..22f68edab 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -1,6 +1,15 @@ "use client"; -import { BookText, Cpu, Earth, Settings, UserKey } from "lucide-react"; +import { + BookText, + Bot, + CircleUser, + Earth, + ImageIcon, + ListChecks, + ScanEye, + UserKey, +} from "lucide-react"; import Link from "next/link"; import { useSelectedLayoutSegment } from "next/navigation"; import { useTranslations } from "next-intl"; @@ -11,7 +20,10 @@ import { cn } from "@/lib/utils"; export type SearchSpaceSettingsTab = | "general" + | "roles" | "models" + | "image-models" + | "vision-models" | "team-roles" | "prompts" | "public-links"; @@ -43,12 +55,27 @@ export function SearchSpaceSettingsLayoutShell({ { value: "general" as const, label: t("nav_general"), - icon: <Settings className="h-4 w-4" />, + icon: <CircleUser className="h-4 w-4" />, + }, + { + value: "roles" as const, + label: t("nav_role_assignments"), + icon: <ListChecks className="h-4 w-4" />, }, { value: "models" as const, - label: t("nav_models"), - icon: <Cpu className="h-4 w-4" />, + label: t("nav_agent_models"), + icon: <Bot className="h-4 w-4" />, + }, + { + value: "image-models" as const, + label: t("nav_image_models"), + icon: <ImageIcon className="h-4 w-4" />, + }, + { + value: "vision-models" as const, + label: t("nav_vision_models"), + icon: <ScanEye className="h-4 w-4" />, }, { value: "team-roles" as const, diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx index c97ef7630..d68194782 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx @@ -1,6 +1,6 @@ -import { ModelConnectionsSettings } from "@/components/settings/model-connections-settings"; +import { AgentModelManager } from "@/components/settings/agent-model-manager"; export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { const { search_space_id } = await params; - return <ModelConnectionsSettings searchSpaceId={Number(search_space_id)} />; + return <AgentModelManager searchSpaceId={Number(search_space_id)} />; } diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx new file mode 100644 index 000000000..5bad50cd3 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx @@ -0,0 +1,6 @@ +import { LLMRoleManager } from "@/components/settings/llm-role-manager"; + +export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { + const { search_space_id } = await params; + return <LLMRoleManager key={search_space_id} searchSpaceId={Number(search_space_id)} />; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx new file mode 100644 index 000000000..06aea003a --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx @@ -0,0 +1,6 @@ +import { VisionModelManager } from "@/components/settings/vision-model-manager"; + +export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { + const { search_space_id } = await params; + return <VisionModelManager searchSpaceId={Number(search_space_id)} />; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx index 4a3c5e9e7..b0cb6699c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx @@ -1,11 +1,10 @@ "use client"; -import { AlertTriangle, RefreshCw, ShieldAlert } from "lucide-react"; +import { RefreshCw, ShieldAlert } from "lucide-react"; import { useParams } from "next/navigation"; import { QRCodeSVG } from "qrcode.react"; import { useCallback, useEffect, useState } from "react"; import { toast } from "sonner"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { @@ -20,7 +19,7 @@ import { Skeleton } from "@/components/ui/skeleton"; import type { SearchSpace } from "@/contracts/types/search-space.types"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cn } from "@/lib/utils"; type GatewayConnection = { @@ -40,7 +39,6 @@ type GatewayConnection = { }; type GatewayConfig = { - enabled: boolean; telegram_enabled: boolean; whatsapp_intake_mode: "disabled" | "cloud" | "baileys"; slack_enabled: boolean; @@ -49,14 +47,6 @@ type GatewayConfig = { type GatewayConfigState = GatewayConfig | null; -const DISABLED_GATEWAY_CONFIG: GatewayConfig = { - enabled: false, - telegram_enabled: false, - whatsapp_intake_mode: "disabled", - slack_enabled: false, - discord_enabled: false, -}; - type Pairing = { binding_id: number; code: string; @@ -90,26 +80,16 @@ export function MessagingChannelsContent() { const whatsappMode = gatewayConfig?.whatsapp_intake_mode ?? "disabled"; const slackGatewayEnabled = gatewayConfig?.slack_enabled ?? false; const discordGatewayEnabled = gatewayConfig?.discord_enabled ?? false; - const gatewayDisabled = gatewayConfig?.enabled === false; const fetchConnections = useCallback(async (platform?: GatewayPlatform) => { - const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/connections", platform ? { platform } : undefined) - ); - if (!res.ok) return []; - const data = await res.json(); - return Array.isArray(data) ? (data as GatewayConnection[]) : []; + const query = platform ? `?platform=${encodeURIComponent(platform)}` : ""; + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/connections${query}`); + return (await res.json()) as GatewayConnection[]; }, []); - const fetchGatewayConfig = useCallback(async (): Promise<GatewayConfig> => { - const res = await authenticatedFetch(buildBackendUrl("/api/v1/gateway/config")); - if (!res.ok) return DISABLED_GATEWAY_CONFIG; - const data = (await res.json()) as Partial<GatewayConfig>; - return { - ...DISABLED_GATEWAY_CONFIG, - ...data, - enabled: data.enabled ?? true, - }; + const fetchGatewayConfig = useCallback(async () => { + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/config`); + return (await res.json()) as GatewayConfig; }, []); const refresh = useCallback(async () => { @@ -145,9 +125,7 @@ export function MessagingChannelsContent() { const refreshBaileysHealth = useCallback(async () => { if (whatsappMode !== "baileys") return; - const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/whatsapp/baileys/health") - ); + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/whatsapp/baileys/health`); if (!res.ok) return; const data = (await res.json()) as BaileysHealth; setBaileysHealth(data); @@ -158,7 +136,7 @@ export function MessagingChannelsContent() { }, [refreshBaileysHealth]); async function startPairing(platform: PairingPlatform) { - const res = await authenticatedFetch(buildBackendUrl("/api/v1/gateway/bindings/start"), { + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/bindings/start`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ platform, search_space_id: searchSpaceId }), @@ -170,7 +148,7 @@ export function MessagingChannelsContent() { async function installSlackGateway() { const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/slack/install", { search_space_id: searchSpaceId }) + `${BACKEND_URL}/api/v1/gateway/slack/install?search_space_id=${searchSpaceId}` ); if (!res.ok) return; const data = (await res.json()) as { auth_url?: string }; @@ -181,7 +159,7 @@ export function MessagingChannelsContent() { async function installDiscordGateway() { const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/discord/install", { search_space_id: searchSpaceId }) + `${BACKEND_URL}/api/v1/gateway/discord/install?search_space_id=${searchSpaceId}` ); if (!res.ok) return; const data = (await res.json()) as { auth_url?: string }; @@ -203,8 +181,8 @@ export function MessagingChannelsContent() { async function revoke(connection: GatewayConnection) { const url = connection.route_type === "account" && connection.account_id - ? buildBackendUrl(`/api/v1/gateway/accounts/${connection.account_id}`) - : buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}`); + ? `${BACKEND_URL}/api/v1/gateway/accounts/${connection.account_id}` + : `${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}`; await authenticatedFetch(url, { method: "DELETE", }); @@ -227,8 +205,8 @@ export function MessagingChannelsContent() { ); const url = connection.route_type === "account" && connection.account_id - ? buildBackendUrl(`/api/v1/gateway/accounts/${connection.account_id}/search-space`) - : buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}/search-space`); + ? `${BACKEND_URL}/api/v1/gateway/accounts/${connection.account_id}/search-space` + : `${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}/search-space`; const res = await authenticatedFetch(url, { method: "PATCH", headers: { "Content-Type": "application/json" }, @@ -244,7 +222,7 @@ export function MessagingChannelsContent() { } async function resume(connection: GatewayConnection) { - await authenticatedFetch(buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}/resume`), { + await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}/resume`, { method: "POST", }); await refreshPlatform(connection.platform as GatewayPlatform); @@ -403,21 +381,7 @@ export function MessagingChannelsContent() { <div className="grid items-stretch gap-3 sm:grid-cols-2"> {isGatewayConfigLoading ? renderGatewaySkeletons() : null} - {!isGatewayConfigLoading && gatewayDisabled ? ( - <Alert className="col-span-full" variant="warning"> - <AlertTriangle aria-hidden /> - <AlertTitle>Messaging Channels coming soon</AlertTitle> - <AlertDescription> - <p> - Soon you'll be able to connect WhatsApp, Telegram, Slack, and Discord to your - SurfSense agent so you can ask questions, route messages to search spaces, and get - answers from your knowledge base without leaving your chat app. - </p> - </AlertDescription> - </Alert> - ) : null} - - {!isGatewayConfigLoading && !gatewayDisabled && !hasEnabledGateway ? ( + {!isGatewayConfigLoading && !hasEnabledGateway ? ( <Card className="col-span-full border-accent bg-accent/20"> <CardHeader className="space-y-1.5 p-4"> <CardTitle className="text-sm">No messaging gateways enabled</CardTitle> @@ -425,7 +389,7 @@ export function MessagingChannelsContent() { </Card> ) : null} - {!gatewayDisabled && telegramGatewayEnabled ? ( + {telegramGatewayEnabled ? ( <Card className="order-1 group relative h-full overflow-hidden border-accent bg-accent/20 transition-all duration-200 hover:shadow-md"> <CardHeader className="space-y-1.5 p-4 pb-2"> <div className="flex items-center justify-between gap-3"> @@ -461,7 +425,7 @@ export function MessagingChannelsContent() { </Card> ) : null} - {!gatewayDisabled && slackGatewayEnabled ? ( + {slackGatewayEnabled ? ( <Card className="order-4 group relative h-full overflow-hidden border-accent bg-accent/20 transition-all duration-200 hover:shadow-md"> <CardHeader className="space-y-1.5 p-4 pb-2"> <div className="flex items-center justify-between gap-3"> @@ -493,7 +457,7 @@ export function MessagingChannelsContent() { </Card> ) : null} - {!gatewayDisabled && discordGatewayEnabled ? ( + {discordGatewayEnabled ? ( <Card className="order-3 group relative h-full overflow-hidden border-accent bg-accent/20 transition-all duration-200 hover:shadow-md"> <CardHeader className="space-y-1.5 p-4 pb-2"> <div className="flex items-center justify-between gap-3"> @@ -525,7 +489,7 @@ export function MessagingChannelsContent() { </Card> ) : null} - {!gatewayDisabled && whatsappMode !== "disabled" ? ( + {whatsappMode !== "disabled" ? ( <Card className="order-2 group relative h-full overflow-hidden border-accent bg-accent/20 transition-all duration-200 hover:shadow-md"> <CardHeader className="space-y-1.5 p-4 pb-2"> <div className="flex items-center justify-between gap-3"> diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 263a286c1..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -13,21 +13,25 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; -import type { CreditPurchase, PagePurchase, PurchaseStatus } from "@/contracts/types/stripe.types"; +import type { + PagePurchase, + PagePurchaseStatus, + TokenPurchase, +} from "@/contracts/types/stripe.types"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { cn } from "@/lib/utils"; -type PurchaseKind = "pages" | "credits"; +type PurchaseKind = "pages" | "tokens"; type UnifiedPurchase = { id: string; kind: PurchaseKind; created_at: string; - status: PurchaseStatus; + status: PagePurchaseStatus; /** * Granted units. Interpretation depends on ``kind``: - * - ``"pages"`` — integer number of indexed pages (legacy history). - * - ``"credits"`` — integer micro-USD of credit (1_000_000 = $1.00). + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). * The ``Granted`` column formats accordingly. */ granted: number; @@ -35,7 +39,7 @@ type UnifiedPurchase = { currency: string | null; }; -const STATUS_STYLES: Record<PurchaseStatus, { label: string; className: string }> = { +const STATUS_STYLES: Record<PagePurchaseStatus, { label: string; className: string }> = { completed: { label: "Completed", className: "bg-emerald-600 text-white border-transparent hover:bg-emerald-600", @@ -59,8 +63,8 @@ const KIND_META: Record< icon: FileText, iconClass: "text-sky-500", }, - credits: { - label: "Credits", + tokens: { + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -93,10 +97,10 @@ function normalizePagePurchase(p: PagePurchase): UnifiedPurchase { }; } -function normalizeCreditPurchase(p: CreditPurchase): UnifiedPurchase { +function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { return { id: p.id, - kind: "credits", + kind: "tokens", created_at: p.created_at, status: p.status, granted: p.credit_micros_granted, @@ -106,10 +110,10 @@ function normalizeCreditPurchase(p: CreditPurchase): UnifiedPurchase { } function formatGranted(p: UnifiedPurchase): string { - if (p.kind === "credits") { + if (p.kind === "tokens") { const dollars = p.granted / 1_000_000; - // Credit packs are always whole dollars at the moment, but future - // fractional grants (refunds, partial top-ups, auto-reload) shouldn't + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't // silently round to "$0". if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; @@ -123,26 +127,26 @@ export function PurchaseHistoryContent() { queries: [ { queryKey: ["stripe-purchases"], - queryFn: () => stripeApiService.getPagePurchases(), + queryFn: () => stripeApiService.getPurchases(), }, { - queryKey: ["stripe-credit-purchases"], - queryFn: () => stripeApiService.getCreditPurchases(), + queryKey: ["stripe-token-purchases"], + queryFn: () => stripeApiService.getTokenPurchases(), }, ], }); - const [pagesQuery, creditsQuery] = results; - const isLoading = pagesQuery.isLoading || creditsQuery.isLoading; + const [pagesQuery, tokensQuery] = results; + const isLoading = pagesQuery.isLoading || tokensQuery.isLoading; const purchases = useMemo<UnifiedPurchase[]>(() => { const pagePurchases = pagesQuery.data?.purchases ?? []; - const creditPurchases = creditsQuery.data?.purchases ?? []; + const tokenPurchases = tokensQuery.data?.purchases ?? []; return [ ...pagePurchases.map(normalizePagePurchase), - ...creditPurchases.map(normalizeCreditPurchase), + ...tokenPurchases.map(normalizeTokenPurchase), ].sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()); - }, [pagesQuery.data, creditsQuery.data]); + }, [pagesQuery.data, tokensQuery.data]); if (isLoading) { return ( @@ -158,7 +162,7 @@ export function PurchaseHistoryContent() { <ReceiptText className="h-8 w-8 text-muted-foreground" /> <p className="text-sm font-medium">No purchases yet</p> <p className="text-xs text-muted-foreground"> - Your credit purchases will appear here after checkout. + Your page and premium credit purchases will appear here after checkout. </p> </div> ); diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx index 55647fe29..3fa08c278 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx @@ -1,11 +1,5 @@ -import { AutoReloadSettings } from "@/components/settings/auto-reload-settings"; import { PurchaseHistoryContent } from "../components/PurchaseHistoryContent"; export default function Page() { - return ( - <div className="space-y-6"> - <AutoReloadSettings /> - <PurchaseHistoryContent /> - </div> - ); + return <PurchaseHistoryContent />; } diff --git a/surfsense_web/app/dashboard/dashboard-shell.tsx b/surfsense_web/app/dashboard/dashboard-shell.tsx deleted file mode 100644 index f84cd56eb..000000000 --- a/surfsense_web/app/dashboard/dashboard-shell.tsx +++ /dev/null @@ -1,42 +0,0 @@ -"use client"; - -import { useEffect, useState } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; -import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; -import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { queryClient } from "@/lib/query-client/client"; - -export function DashboardShell({ children }: { children: React.ReactNode }) { - const [isCheckingAuth, setIsCheckingAuth] = useState(true); - - // Use the global loading screen - spinner animation won't reset - useGlobalLoadingEffect(isCheckingAuth); - - useEffect(() => { - async function checkAuth() { - let token = getBearerToken(); - if (!token) { - const synced = await ensureTokensFromElectron(); - if (synced) token = getBearerToken(); - } - if (!token) { - redirectToLogin(); - return; - } - queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] }); - setIsCheckingAuth(false); - } - checkAuth(); - }, []); - - // Return null while loading - the global provider handles the loading UI - if (isCheckingAuth) { - return null; - } - - return ( - <div className="h-full flex flex-col "> - <div className="flex-1 min-h-0">{children}</div> - </div> - ); -} diff --git a/surfsense_web/app/dashboard/layout.tsx b/surfsense_web/app/dashboard/layout.tsx index 6212c92e7..1f5481b15 100644 --- a/surfsense_web/app/dashboard/layout.tsx +++ b/surfsense_web/app/dashboard/layout.tsx @@ -1,14 +1,46 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; -import { DashboardShell } from "./dashboard-shell"; +"use client"; + +import { useEffect, useState } from "react"; +import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; +import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; +import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; +import { queryClient } from "@/lib/query-client/client"; interface DashboardLayoutProps { children: React.ReactNode; } export default function DashboardLayout({ children }: DashboardLayoutProps) { + const [isCheckingAuth, setIsCheckingAuth] = useState(true); + + // Use the global loading screen - spinner animation won't reset + useGlobalLoadingEffect(isCheckingAuth); + + useEffect(() => { + async function checkAuth() { + let token = getBearerToken(); + if (!token) { + const synced = await ensureTokensFromElectron(); + if (synced) token = getBearerToken(); + } + if (!token) { + redirectToLogin(); + return; + } + queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] }); + setIsCheckingAuth(false); + } + checkAuth(); + }, []); + + // Return null while loading - the global provider handles the loading UI + if (isCheckingAuth) { + return null; + } + return ( - <RuntimeConfig> - <DashboardShell>{children}</DashboardShell> - </RuntimeConfig> + <div className="h-full flex flex-col "> + <div className="flex-1 min-h-0">{children}</div> + </div> ); } diff --git a/surfsense_web/app/desktop/login/layout.tsx b/surfsense_web/app/desktop/login/layout.tsx deleted file mode 100644 index 83556d314..000000000 --- a/surfsense_web/app/desktop/login/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function DesktopLoginLayout({ children }: { children: React.ReactNode }) { - return <RuntimeConfig>{children}</RuntimeConfig>; -} diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 0d91588e1..41c956f3e 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -8,7 +8,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; -import { useIsGoogleAuth } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; @@ -18,8 +17,9 @@ import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { setBearerToken } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; +const isGoogleAuth = AUTH_TYPE === "GOOGLE"; type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; @@ -189,7 +189,6 @@ function HotkeyRow({ export default function DesktopLoginPage() { const router = useRouter(); const api = useElectronAPI(); - const isGoogleAuth = useIsGoogleAuth(); const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom); const [email, setEmail] = useState(""); @@ -240,7 +239,7 @@ export default function DesktopLoginPage() { const handleGoogleLogin = () => { if (isGoogleRedirecting) return; setIsGoogleRedirecting(true); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; const autoSetSearchSpace = async () => { diff --git a/surfsense_web/app/docs/layout.tsx b/surfsense_web/app/docs/layout.tsx index cc5f2118a..9311a45b4 100644 --- a/surfsense_web/app/docs/layout.tsx +++ b/surfsense_web/app/docs/layout.tsx @@ -8,15 +8,12 @@ const gridTemplate = `"sidebar header toc" "sidebar toc-popover toc" "sidebar main toc" 1fr / var(--fd-sidebar-col) minmax(0, 1fr) min-content`; -const docsSurfaceClass = - "bg-main-panel [--color-fd-background:var(--main-panel)] [--color-fd-card:var(--main-panel)] [--color-fd-popover:var(--main-panel)] [--color-fd-muted:var(--main-panel)] [--color-fd-secondary:var(--main-panel)]"; - export default function Layout({ children }: { children: ReactNode }) { return ( <DocsLayout tree={source.pageTree} {...baseOptions} - containerProps={{ style: { gridTemplate }, className: docsSurfaceClass }} + containerProps={{ style: { gridTemplate }, className: "bg-fd-card" }} sidebar={{ components: { Separator: SidebarSeparator, diff --git a/surfsense_web/app/globals.css b/surfsense_web/app/globals.css index 4a29edfa6..3cdb34bff 100644 --- a/surfsense_web/app/globals.css +++ b/surfsense_web/app/globals.css @@ -58,11 +58,6 @@ --highlight: oklch(0.852 0.199 91.936); } -html[data-surfsense-auth-type="GOOGLE"] .runtime-auth-local, -html[data-surfsense-auth-type="LOCAL"] .runtime-auth-google { - display: none; -} - .dark { --background: oklch(0.145 0 0); --foreground: oklch(0.985 0 0); diff --git a/surfsense_web/app/layout.tsx b/surfsense_web/app/layout.tsx index 46182f40e..eef03d463 100644 --- a/surfsense_web/app/layout.tsx +++ b/surfsense_web/app/layout.tsx @@ -2,7 +2,6 @@ import type { Metadata, Viewport } from "next"; import "./globals.css"; import { RootProvider } from "fumadocs-ui/provider/next"; import { Roboto } from "next/font/google"; -import Script from "next/script"; import { AnnouncementToastProvider } from "@/components/announcements/AnnouncementToastProvider"; import { DesktopUpdateToast } from "@/components/desktop/desktop-update-toast"; import { GlobalLoadingProvider } from "@/components/providers/GlobalLoadingProvider"; @@ -17,13 +16,8 @@ import { import { ThemeProvider } from "@/components/theme/theme-provider"; import { Toaster } from "@/components/ui/sonner"; import { LocaleProvider } from "@/contexts/LocaleContext"; -import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config"; import { PlatformProvider } from "@/contexts/platform-context"; import { ReactQueryClientProvider } from "@/lib/query-client/query-client.provider"; -import { - getRuntimeAuthInitScript, - resolveRuntimeAuthUiMode, -} from "@/lib/runtime-auth-config"; import { cn } from "@/lib/utils"; const roboto = Roboto({ @@ -52,7 +46,7 @@ export const metadata: Metadata = { alternates: { canonical: "https://www.surfsense.com", }, - title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams", + title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams", description: "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.", keywords: [ @@ -94,7 +88,7 @@ export const metadata: Metadata = { "SurfSense", ], openGraph: { - title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams", + title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams", description: "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude, and any AI model for free.", url: "https://www.surfsense.com", @@ -112,7 +106,7 @@ export const metadata: Metadata = { }, twitter: { card: "summary_large_image", - title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams", + title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams", description: "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.", creator: "@SurfSenseAI", @@ -137,15 +131,8 @@ export default function RootLayout({ // Language can be switched dynamically through LanguageSwitcher component // Locale state is managed by LocaleContext and persisted in localStorage return ( - <html - lang="en" - data-surfsense-auth-type={resolveRuntimeAuthUiMode(BUILD_TIME_AUTH_TYPE)} - suppressHydrationWarning - > + <html lang="en" suppressHydrationWarning> <head> - <Script id="surfsense-runtime-auth-init" strategy="beforeInteractive"> - {getRuntimeAuthInitScript(BUILD_TIME_AUTH_TYPE)} - </Script> <link rel="preconnect" href="https://api.github.com" /> <OrganizationJsonLd /> <WebSiteJsonLd /> diff --git a/surfsense_web/app/sitemap.ts b/surfsense_web/app/sitemap.ts index d0de9fb3e..82ec405f9 100644 --- a/surfsense_web/app/sitemap.ts +++ b/surfsense_web/app/sitemap.ts @@ -1,7 +1,6 @@ import { loader } from "fumadocs-core/source"; import type { MetadataRoute } from "next"; import { blog, changelog } from "@/.source/server"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; import { source as docsSource } from "@/lib/source"; const blogSource = loader({ @@ -15,10 +14,11 @@ const changelogSource = loader({ }); const BASE_URL = "https://www.surfsense.com"; +const BACKEND_URL = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; async function getFreeModelSlugs(): Promise<string[]> { try { - const res = await fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 3600 }, }); if (!res.ok) return []; diff --git a/surfsense_web/app/verify-token/route.ts b/surfsense_web/app/verify-token/route.ts index 9df460779..b7ed762de 100644 --- a/surfsense_web/app/verify-token/route.ts +++ b/surfsense_web/app/verify-token/route.ts @@ -1,16 +1,12 @@ import { type NextRequest, NextResponse } from "next/server"; -function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; - return base.replace(/\/+$/, ""); -} +const backendBaseUrl = (process.env.INTERNAL_FASTAPI_BACKEND_URL || "http://backend:8000").replace( + /\/+$/, + "" +); export async function GET(request: NextRequest) { - const response = await fetch(`${getBackendBaseUrl()}/verify-token`, { + const response = await fetch(`${backendBaseUrl}/verify-token`, { method: "GET", headers: { Authorization: request.headers.get("authorization") || "", diff --git a/surfsense_web/atoms/automations/automations-mutation.atoms.ts b/surfsense_web/atoms/automations/automations-mutation.atoms.ts index 288d97c63..a81cd1578 100644 --- a/surfsense_web/atoms/automations/automations-mutation.atoms.ts +++ b/surfsense_web/atoms/automations/automations-mutation.atoms.ts @@ -57,9 +57,9 @@ export const createAutomationMutationAtom = atomWithMutation(() => ({ task_count: variables.definition.plan.length, trigger_type: variables.triggers?.[0]?.type ?? "none", has_schedule: (variables.triggers?.length ?? 0) > 0, - chat_model_id: variables.definition.models?.chat_model_id, - image_gen_model_id: variables.definition.models?.image_gen_model_id, - vision_model_id: variables.definition.models?.vision_model_id, + agent_llm_id: variables.definition.models?.agent_llm_id, + image_generation_config_id: variables.definition.models?.image_generation_config_id, + vision_llm_config_id: variables.definition.models?.vision_llm_config_id, tags_count: variables.definition.metadata?.tags?.length, }); }, diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts new file mode 100644 index 000000000..922c398c9 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts @@ -0,0 +1,96 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateImageGenConfigRequest, + CreateImageGenConfigResponse, + DeleteImageGenConfigResponse, + GetImageGenConfigsResponse, + UpdateImageGenConfigRequest, + UpdateImageGenConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new ImageGenerationConfig + */ +export const createImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateImageGenConfigRequest) => { + return imageGenConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => { + toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create image model"); + }, + }; +}); + +/** + * Mutation atom for updating an existing ImageGenerationConfig + */ +export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateImageGenConfigRequest) => { + return imageGenConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update image model"); + }, + }; +}); + +/** + * Mutation atom for deleting an ImageGenerationConfig + */ +export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: { id: number; name: string }) => { + return imageGenConfigApiService.deleteConfig(request.id); + }, + onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + (oldData: GetImageGenConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete image model"); + }, + }; +}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts new file mode 100644 index 000000000..a45e69a03 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts @@ -0,0 +1,33 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching user-created image gen configs for the active search space + */ +export const imageGenConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching global image gen configs (from YAML, negative IDs) + */ +export const globalImageGenConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.imageGenConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + queryFn: async () => { + return imageGenConfigApiService.getGlobalConfigs(); + }, + }; +}); diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts deleted file mode 100644 index f00bf76f9..000000000 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ /dev/null @@ -1,214 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - ConnectionCreateRequest, - ConnectionRead, - ConnectionUpdateRequest, - ModelCreateRequest, - ModelPreviewRead, - ModelRead, - ModelRoles, - ModelsBulkUpdateRequest, - ModelTestPreviewRequest, - ModelUpdateRequest, - VerifyConnectionResponse, -} from "@/contracts/types/model-connections.types"; -import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -function invalidateModelConnections(searchSpaceId: number) { - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.all(searchSpaceId), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - }); -} - -function upsertModelConnection(searchSpaceId: number, connection: ConnectionRead) { - queryClient.setQueryData<ConnectionRead[]>( - cacheKeys.modelConnections.all(searchSpaceId), - (current = []) => { - if (current.some((item) => item.id === connection.id)) { - return current.map((item) => (item.id === connection.id ? connection : item)); - } - return [...current, connection]; - } - ); -} - -export const createModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "create"], - mutationFn: (request: ConnectionCreateRequest) => - modelConnectionsApiService.createConnection(request), - onSuccess: (connection: ConnectionRead, request: ConnectionCreateRequest) => { - const resolvedSearchSpaceId = Number( - request.search_space_id ?? connection.search_space_id ?? searchSpaceId - ); - toast.success("Connection created"); - if (resolvedSearchSpaceId > 0) { - upsertModelConnection(resolvedSearchSpaceId, connection); - invalidateModelConnections(resolvedSearchSpaceId); - } - }, - onError: (error: Error) => toast.error(error.message || "Failed to create connection"), - }; -}); - -export const updateModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "update"], - mutationFn: ({ id, data }: { id: number; data: ConnectionUpdateRequest }) => - modelConnectionsApiService.updateConnection(id, data), - onSuccess: () => { - toast.success("Connection updated"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to update connection"), - }; -}); - -export const deleteModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "delete"], - mutationFn: (id: number) => modelConnectionsApiService.deleteConnection(id), - onSuccess: () => { - toast.success("Connection deleted"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to delete connection"), - }; -}); - -export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "verify"], - mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), - onSuccess: (result: VerifyConnectionResponse) => { - if (result.ok) { - toast.success("Connection verified"); - } else { - // Non-fatal: many providers lack a /models endpoint yet still serve - // chat. Guide the user to add model IDs manually instead of alarming. - toast.warning( - result.message - ? `${result.message} Chat may still work — add model IDs manually.` - : "Couldn't list models. Chat may still work — add model IDs manually." - ); - } - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), - }; -}); - -export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "discover"], - mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: (models: ModelRead[]) => { - toast.success( - models.length ? `${models.length} models discovered` : "No models found for this connection" - ); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to discover models"), - }; -}); - -export const previewConnectionModelsMutationAtom = atomWithMutation(() => { - return { - mutationKey: ["model-connections", "discover-preview"], - mutationFn: (request: ConnectionCreateRequest) => - modelConnectionsApiService.previewModels(request), - onSuccess: (_models: ModelPreviewRead[]) => {}, - onError: (error: Error) => toast.error(error.message || "Failed to discover models"), - }; -}); - -export const testPreviewModelMutationAtom = atomWithMutation(() => { - return { - mutationKey: ["model-connections", "test-preview"], - mutationFn: (request: ModelTestPreviewRequest) => - modelConnectionsApiService.testPreviewModel(request), - onSuccess: (result: VerifyConnectionResponse) => { - if (!result.ok) { - toast.error(result.message || "Model test failed"); - } - }, - onError: (error: Error) => toast.error(error.message || "Failed to test model"), - }; -}); - -export const addManualModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "add-manual"], - mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelCreateRequest }) => - modelConnectionsApiService.addManualModel(connectionId, data), - onSuccess: () => { - toast.success("Model added"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to add model"), - }; -}); - -export const updateModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "update"], - mutationFn: ({ id, data }: { id: number; data: ModelUpdateRequest }) => - modelConnectionsApiService.updateModel(id, data), - onSuccess: () => invalidateModelConnections(searchSpaceId), - onError: (error: Error) => toast.error(error.message || "Failed to update model"), - }; -}); - -export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "bulk-update"], - mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) => - modelConnectionsApiService.bulkUpdateModels(connectionId, data), - onSuccess: () => invalidateModelConnections(searchSpaceId), - onError: (error: Error) => toast.error(error.message || "Failed to update models"), - }; -}); - -export const testModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "test"], - mutationFn: (id: number) => modelConnectionsApiService.testModel(id), - onSuccess: (result: VerifyConnectionResponse) => { - if (result.ok) toast.success("Model test succeeded"); - else toast.error(result.message || "Model test failed"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to test model"), - }; -}); - -export const updateModelRolesMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-roles", "update"], - mutationFn: (roles: ModelRoles) => - modelConnectionsApiService.updateModelRoles(searchSpaceId, roles), - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - }); - }, - onError: (error: Error) => toast.error(error.message || "Failed to update model roles"), - }; -}); diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts deleted file mode 100644 index 04dad9b21..000000000 --- a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const globalModelConnectionsAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.global(), - enabled: !!getBearerToken(), - staleTime: 10 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getGlobalConnections(), -})); - -export const globalLlmConfigStatusAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.globalConfigStatus(), - enabled: !!getBearerToken(), - staleTime: 60 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getGlobalLlmConfigStatus(), -})); - -export const modelProvidersAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.providers(), - enabled: !!getBearerToken(), - staleTime: 60 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getModelProviders(), -})); - -export const modelConnectionsAtom = atomWithQuery((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - queryKey: cacheKeys.modelConnections.all(searchSpaceId), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getConnections(searchSpaceId), - }; -}); - -export const modelRolesAtom = atomWithQuery((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getModelRoles(searchSpaceId), - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts new file mode 100644 index 000000000..476d89d4c --- /dev/null +++ b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts @@ -0,0 +1,132 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateNewLLMConfigRequest, + CreateNewLLMConfigResponse, + DeleteNewLLMConfigRequest, + DeleteNewLLMConfigResponse, + GetNewLLMConfigsResponse, + UpdateLLMPreferencesRequest, + UpdateNewLLMConfigRequest, + UpdateNewLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new NewLLMConfig + */ +export const createNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateNewLLMConfigRequest) => { + return newLLMConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => { + toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create model"); + }, + }; +}); + +/** + * Mutation atom for updating an existing NewLLMConfig + */ +export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateNewLLMConfigRequest) => { + return newLLMConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update"); + }, + }; +}); + +/** + * Mutation atom for deleting a NewLLMConfig + */ +export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => { + return newLLMConfigApiService.deleteConfig({ id: request.id }); + }, + onSuccess: ( + _: DeleteNewLLMConfigResponse, + request: DeleteNewLLMConfigRequest & { name: string } + ) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + (oldData: GetNewLLMConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete"); + }, + }; +}); + +/** + * Mutation atom for updating LLM preferences (role assignments) + */ +export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["llm-preferences", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateLLMPreferencesRequest) => { + return newLLMConfigApiService.updateLLMPreferences(request); + }, + onSuccess: (_data, request: UpdateLLMPreferencesRequest) => { + queryClient.setQueryData( + cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), + (old: Record<string, unknown> | undefined) => ({ ...old, ...request.data }) + ); + // Automation eligibility is derived from these model preferences + // (agent/image/vision). Invalidate it so the automations gate alert + // reflects the new selection without a manual refresh. + queryClient.invalidateQueries({ + queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update LLM preferences"); + }, + }; +}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts new file mode 100644 index 000000000..410d061e5 --- /dev/null +++ b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts @@ -0,0 +1,98 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import type { LLMModel } from "@/contracts/enums/llm-models"; +import { LLM_MODELS } from "@/contracts/enums/llm-models"; +import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching all NewLLMConfigs for the active search space + */ +export const newLLMConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return newLLMConfigApiService.getConfigs({ + search_space_id: Number(searchSpaceId), + }); + }, + }; +}); + +/** + * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs) + */ +export const globalNewLLMConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + enabled: !!getBearerToken(), + queryFn: async () => { + return newLLMConfigApiService.getGlobalConfigs(); + }, + }; +}); + +/** + * Query atom for fetching LLM preferences (role assignments) for the active search space + */ +export const llmPreferencesAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching default system instructions template + */ +export const defaultSystemInstructionsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.defaultInstructions(), + staleTime: 60 * 60 * 1000, // 1 hour - this rarely changes + queryFn: async () => { + return newLLMConfigApiService.getDefaultSystemInstructions(); + }, + }; +}); + +/** + * Query atom for the dynamic model catalogue. + * Fetched from the backend (which proxies OpenRouter's public API). + * Falls back to the static hardcoded list on error. + */ +export const modelListAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.modelList(), + staleTime: 60 * 60 * 1000, // 1 hour - models don't change often + placeholderData: LLM_MODELS, + queryFn: async (): Promise<LLMModel[]> => { + const data = await newLLMConfigApiService.getModels(); + const dynamicModels = data.map((m) => ({ + value: m.value, + label: m.label, + provider: m.provider, + contextWindow: m.context_window ?? undefined, + })); + + // Providers covered by the dynamic API (from OpenRouter mapping). + // For uncovered providers (Ollama, Groq, Bedrock, etc.) keep the + // hand-curated static suggestions so users still see model options. + const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); + const staticFallbacks = LLM_MODELS.filter((m) => !coveredProviders.has(m.provider)); + + return [...dynamicModels, ...staticFallbacks]; + }, + }; +}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts new file mode 100644 index 000000000..f46b977d5 --- /dev/null +++ b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts @@ -0,0 +1,87 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateVisionLLMConfigRequest, + CreateVisionLLMConfigResponse, + DeleteVisionLLMConfigResponse, + GetVisionLLMConfigsResponse, + UpdateVisionLLMConfigRequest, + UpdateVisionLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const createVisionLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateVisionLLMConfigRequest) => { + return visionLLMConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateVisionLLMConfigResponse, request: CreateVisionLLMConfigRequest) => { + toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create vision model"); + }, + }; +}); + +export const updateVisionLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateVisionLLMConfigRequest) => { + return visionLLMConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateVisionLLMConfigResponse, request: UpdateVisionLLMConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update vision model"); + }, + }; +}); + +export const deleteVisionLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: { id: number; name: string }) => { + return visionLLMConfigApiService.deleteConfig(request.id); + }, + onSuccess: (_: DeleteVisionLLMConfigResponse, request: { id: number; name: string }) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + (oldData: GetVisionLLMConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete vision model"); + }, + }; +}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts new file mode 100644 index 000000000..906ce638f --- /dev/null +++ b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts @@ -0,0 +1,51 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import type { LLMModel } from "@/contracts/enums/llm-models"; +import { VISION_MODELS } from "@/contracts/enums/vision-providers"; +import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const visionLLMConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: async () => { + return visionLLMConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +export const globalVisionLLMConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.visionLLMConfigs.global(), + staleTime: 10 * 60 * 1000, + queryFn: async () => { + return visionLLMConfigApiService.getGlobalConfigs(); + }, + }; +}); + +export const visionModelListAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.visionLLMConfigs.modelList(), + staleTime: 60 * 60 * 1000, + placeholderData: VISION_MODELS, + queryFn: async (): Promise<LLMModel[]> => { + const data = await visionLLMConfigApiService.getModels(); + const dynamicModels = data.map((m) => ({ + value: m.value, + label: m.label, + provider: m.provider, + contextWindow: m.context_window ?? undefined, + })); + + const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); + const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); + + return [...dynamicModels, ...staticFallbacks]; + }, + }; +}); diff --git a/surfsense_web/components/agent-action-log/action-log-dialog.tsx b/surfsense_web/components/agent-action-log/action-log-dialog.tsx index 5f3b83db1..1d0eefc17 100644 --- a/surfsense_web/components/agent-action-log/action-log-dialog.tsx +++ b/surfsense_web/components/agent-action-log/action-log-dialog.tsx @@ -2,7 +2,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; -import { RefreshCw, Workflow } from "lucide-react"; +import { RefreshCcw, Workflow } from "lucide-react"; import { useCallback } from "react"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; @@ -112,7 +112,7 @@ export function ActionLogDialog() { className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" aria-label="Refresh action log" > - <RefreshCw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} /> + <RefreshCcw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} /> </Button> <div className="flex min-h-0 flex-1 flex-col overflow-y-auto scrollbar-thin"> diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 59006b26e..fd24600c2 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -26,9 +26,9 @@ import type { FC } from "react"; import { useEffect, useMemo, useRef, useState } from "react"; import { commentsEnabledAtom, targetCommentIdAtom } from "@/atoms/chat/current-thread.atom"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + globalNewLLMConfigsAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { CitationMetadataProvider, @@ -37,10 +37,7 @@ import { import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; -import { - type TokenUsageModelBreakdown, - useTokenUsage, -} from "@/components/assistant-ui/token-usage-context"; +import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container"; import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet"; @@ -87,7 +84,7 @@ const GenerateResumeToolUI = dynamic( ); const GeneratePodcastToolUI = dynamic( () => - import("@/components/tool-ui/podcast").then((m) => ({ + import("@/components/tool-ui/generate-podcast").then((m) => ({ default: m.GeneratePodcastToolUI, })), { ssr: false } @@ -271,81 +268,29 @@ function formatTurnCost(micros: number): string { return "$0"; } -function normalizeUsageModelKey(modelKey: string): string { - return modelKey.trim().replace(/^~/, ""); -} - -function bareModelKey(modelKey: string): string { - const normalized = normalizeUsageModelKey(modelKey); - const parts = normalized.split("/"); - return parts[parts.length - 1] || normalized; -} - -function inferProviderFromModelKey(modelKey: string) { - const normalized = normalizeUsageModelKey(modelKey); - const [provider] = normalized.split("/"); - return provider && provider !== normalized ? provider : null; -} - -function titleCaseModelPart(part: string) { - if (!part) return ""; - const upper = part.toUpperCase(); - if (/^\d+(\.\d+)?[BKM]$/.test(upper)) return upper; - if (["gpt", "oai", "api", "llm", "vlm"].includes(part.toLowerCase())) return upper; - return part.charAt(0).toUpperCase() + part.slice(1); -} - -function humanizeModelId(modelKey: string): string { - const bare = bareModelKey(modelKey) - .replace(/:latest$/i, "") - .replace(/[-_]+/g, " ") - .trim(); - if (!bare) return modelKey; - return bare.split(/\s+/).map(titleCaseModelPart).join(" "); -} - const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ chatTurnId }) => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); const usage = useTokenUsage(messageId); - const { data: globalConnections = [] } = useAtomValue(globalModelConnectionsAtom); - const { data: localConnections = [] } = useAtomValue(modelConnectionsAtom); + const { data: localConfigs } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); - const modelConnectionByKey = useMemo(() => { - const map = new Map<string, { name: string; provider: string; modelId: string }>(); - for (const connection of [...globalConnections, ...localConnections]) { - for (const model of connection.models) { - const normalizedModelId = normalizeUsageModelKey(model.model_id); - const entry = { - name: model.display_name || model.model_id, - provider: connection.provider, - modelId: model.model_id, - }; - map.set(model.model_id, entry); - map.set(normalizedModelId, entry); - map.set(bareModelKey(model.model_id), entry); - } + const configByModel = useMemo(() => { + const map = new Map<string, { name: string; provider: string }>(); + for (const c of [...(globalConfigs ?? []), ...(localConfigs ?? [])]) { + map.set(c.model_name, { name: c.name, provider: c.provider }); } return map; - }, [globalConnections, localConnections]); + }, [localConfigs, globalConfigs]); - const resolveModel = (modelKey: string, counts: TokenUsageModelBreakdown) => { - const normalizedKey = normalizeUsageModelKey(counts.model_id || counts.model || modelKey); - const connectionModel = - modelConnectionByKey.get(modelKey) ?? - modelConnectionByKey.get(normalizeUsageModelKey(modelKey)) ?? - modelConnectionByKey.get(normalizedKey) ?? - modelConnectionByKey.get(bareModelKey(normalizedKey)); - const provider = - counts.provider || connectionModel?.provider || inferProviderFromModelKey(normalizedKey); - const modelId = counts.model_id || connectionModel?.modelId || modelKey; - const name = counts.display_name || connectionModel?.name || humanizeModelId(modelId); - return { - name, - modelId, - icon: provider ? getProviderIcon(provider, { className: "size-3.5 shrink-0" }) : null, - }; + const resolveModel = (modelKey: string) => { + const parts = modelKey.split("/"); + const bare = parts[parts.length - 1] ?? modelKey; + const config = configByModel.get(modelKey) ?? configByModel.get(bare); + return config + ? { name: config.name, icon: getProviderIcon(config.provider, { className: "size-3.5" }) } + : { name: modelKey, icon: null }; }; const modelBreakdown = usage ? (usage.usage ?? usage.model_breakdown) : undefined; @@ -374,12 +319,12 @@ const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ ch <ActionBarMorePrimitive.Separator className="bg-popover-border mx-1 my-1 h-px" /> {models.length > 0 ? ( models.map(([model, counts]) => { - const { name, icon } = resolveModel(model, counts); + const { name, icon } = resolveModel(model); const costMicros = counts.cost_micros; return ( <ActionBarMorePrimitive.Item key={model} - className="focus:bg-accent focus:text-accent-foreground relative flex cursor-default flex-col items-start gap-1 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" + className="focus:bg-accent focus:text-accent-foreground relative flex cursor-default flex-col items-start gap-0.5 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" onSelect={(e) => e.preventDefault()} > <span className="flex items-center gap-1.5 text-xs font-medium"> diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx index 695e97d7b..a9231d846 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/obsidian-connect-form.tsx @@ -6,6 +6,7 @@ import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { EnumConnectorName } from "@/contracts/enums/connector"; import { useApiKey } from "@/hooks/use-api-key"; +import { BACKEND_URL } from "@/lib/env-config"; import { getConnectorBenefits } from "../connector-benefits"; import type { ConnectFormProps } from "../index"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/circleback-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/circleback-config.tsx index 283c052cb..4de8500a6 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/circleback-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/circleback-config.tsx @@ -9,7 +9,7 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import type { ConnectorConfigProps } from "../index"; export interface CirclebackConfigProps extends ConnectorConfigProps { onNameChange?: (name: string) => void; @@ -42,10 +42,17 @@ export const CirclebackConfig: FC<CirclebackConfigProps> = ({ connector, onNameC const doFetch = async () => { if (!connector.search_space_id) return; + const baseUrl = BACKEND_URL; + if (!baseUrl) { + console.error("NEXT_PUBLIC_FASTAPI_BACKEND_URL is not configured"); + setIsLoading(false); + return; + } + setIsLoading(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/webhooks/circleback/${connector.search_space_id}/info`), + `${baseUrl}/api/v1/webhooks/circleback/${connector.search_space_id}/info`, { signal: controller.signal } ); if (controller.signal.aborted) return; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index 1fc555471..011eeec96 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -13,7 +13,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getReauthEndpoint } from "@/lib/connector-telemetry"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cn } from "@/lib/utils"; import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; @@ -95,13 +95,12 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({ if (!spaceId || !reauthEndpoint) return; setReauthing(true); try { - const response = await authenticatedFetch( - buildBackendUrl(reauthEndpoint, { - connector_id: connector.id, - space_id: spaceId, - return_url: window.location.pathname, - }) - ); + const backendUrl = BACKEND_URL; + const url = new URL(`${backendUrl}${reauthEndpoint}`); + url.searchParams.set("connector_id", String(connector.id)); + url.searchParams.set("space_id", String(spaceId)); + url.searchParams.set("return_url", window.location.pathname); + const response = await authenticatedFetch(url.toString()); if (!response.ok) { const data = await response.json().catch(() => ({})); toast.error(data.detail ?? "Failed to initiate re-authentication."); diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index 2f10152b8..45c174d74 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -16,7 +16,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { searchSourceConnector } from "@/contracts/types/connector.types"; import { OAUTH_RESULT_COOKIE, parseOAuthCallbackResult } from "@/contracts/types/oauth.types"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackConnectorConnected, trackConnectorDeleted, @@ -351,7 +351,9 @@ export const useConnectorDialog = () => { trackConnectorSetupStarted(Number(searchSpaceId), connector.connectorType, "oauth_click"); try { - const url = buildBackendUrl(connector.authEndpoint, { space_id: searchSpaceId }); + // Check if authEndpoint already has query parameters + const separator = connector.authEndpoint.includes("?") ? "&" : "?"; + const url = `${BACKEND_URL}${connector.authEndpoint}${separator}space_id=${searchSpaceId}`; const response = await authenticatedFetch(url, { method: "GET" }); diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx index f7b27441b..4977219f7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx @@ -2,10 +2,10 @@ import { Search } from "lucide-react"; import type { FC } from "react"; -import { useIsSelfHosted } from "@/components/providers/runtime-config"; import { EnumConnectorName } from "@/contracts/enums/connector"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { usePlatform } from "@/hooks/use-platform"; +import { isSelfHosted } from "@/lib/env-config"; import { ConnectorCard } from "../components/connector-card"; import { COMPOSIO_CONNECTORS, @@ -22,11 +22,6 @@ type OAuthConnector = (typeof OAUTH_CONNECTORS)[number]; type ComposioConnector = (typeof COMPOSIO_CONNECTORS)[number]; type OtherConnector = (typeof OTHER_CONNECTORS)[number]; type CrawlerConnector = (typeof CRAWLERS)[number]; -type DeploymentFilterableConnector = { - readonly id: string; - readonly selfHostedOnly?: boolean; - readonly desktopOnly?: boolean; -}; /** * Extract the display name from a full connector name. @@ -71,14 +66,14 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({ onManage, onViewAccountsList, }) => { - const selfHosted = useIsSelfHosted(); + const selfHosted = isSelfHosted(); const { isDesktop } = usePlatform(); const matchesSearch = (title: string, description: string) => title.toLowerCase().includes(searchQuery.toLowerCase()) || description.toLowerCase().includes(searchQuery.toLowerCase()); - const passesDeploymentFilter = (c: DeploymentFilterableConnector) => + const passesDeploymentFilter = (c: { selfHostedOnly?: boolean; desktopOnly?: boolean }) => (!c.selfHostedOnly || selfHosted) && (!c.desktopOnly || isDesktop); // Filter connectors based on search and deployment mode diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index f53537cdc..05b684397 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -12,7 +12,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getReauthEndpoint } from "@/lib/connector-telemetry"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; @@ -61,13 +61,12 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({ if (!searchSpaceId || !endpoint) return; setReauthingId(connector.id); try { - const response = await authenticatedFetch( - buildBackendUrl(endpoint, { - connector_id: connector.id, - space_id: searchSpaceId, - return_url: window.location.pathname, - }) - ); + const backendUrl = BACKEND_URL; + const url = new URL(`${backendUrl}${endpoint}`); + url.searchParams.set("connector_id", String(connector.id)); + url.searchParams.set("space_id", String(searchSpaceId)); + url.searchParams.set("return_url", window.location.pathname); + const response = await authenticatedFetch(url.toString()); if (!response.ok) { const data = await response.json().catch(() => ({})); toast.error(data.detail ?? "Failed to initiate re-authentication."); diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index cd32ca920..ee36e8499 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -110,7 +110,7 @@ const MarkdownTextImpl = () => { return ( <CitationUrlMapContext.Provider value={urlMapRef}> <MarkdownTextPrimitive - smooth + smooth={false} remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]} rehypePlugins={[rehypeKatex]} className="aui-md" diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index c8da125f4..5796109f0 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; @@ -68,7 +68,6 @@ import { import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { UserMessage } from "@/components/assistant-ui/user-message"; import { ChatExamplePrompts } from "@/components/new-chat/chat-example-prompts"; -import { ChatHeader } from "@/components/new-chat/chat-header"; import { ComposerSuggestionPopoverContent } from "@/components/new-chat/composer-suggestion-popup"; import { PromptPicker, type PromptPickerRef } from "@/components/new-chat/prompt-picker"; import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar"; @@ -145,15 +144,11 @@ function getComposerSuggestionAnchorPoint( }; } -interface ThreadProps { - hasActiveThread?: boolean; -} - -export const Thread: FC<ThreadProps> = ({ hasActiveThread = false }) => { - return <ThreadContent hasActiveThread={hasActiveThread} />; +export const Thread: FC = () => { + return <ThreadContent />; }; -const ThreadContent: FC<ThreadProps> = ({ hasActiveThread = false }) => { +const ThreadContent: FC = () => { return ( <ThreadPrimitive.Root className="aui-root aui-thread-root @container flex h-full min-h-0 flex-col bg-main-panel" @@ -163,13 +158,13 @@ const ThreadContent: FC<ThreadProps> = ({ hasActiveThread = false }) => { > <ChatViewport footer={ - <AuiIf condition={({ thread }) => hasActiveThread || !thread.isEmpty}> + <AuiIf condition={({ thread }) => !thread.isEmpty}> <PremiumQuotaPinnedAlert /> <Composer /> </AuiIf> } > - <AuiIf condition={({ thread }) => !hasActiveThread && thread.isEmpty}> + <AuiIf condition={({ thread }) => thread.isEmpty}> <ThreadWelcome /> </AuiIf> @@ -522,11 +517,6 @@ const Composer: FC = () => { editorRef.current?.focus(); }, [isDesktop, showDocumentPopover, showPromptPicker, threadId]); - const handleChatModelSelected = useCallback(() => { - if (!isDesktop) return; - editorRef.current?.focus(); - }, [isDesktop]); - // Close document picker when a sidebar slide-out panel (inbox, etc.) opens. // React only on changes to the tick — comparing against the previously-seen // value preserves the one-shot semantics of the prior window-event approach @@ -937,11 +927,7 @@ const Composer: FC = () => { className="min-h-[48px] sm:min-h-[24px] **:data-slate-placeholder:font-normal" /> </div> - <ComposerAction - isBlockedByOtherUser={isBlockedByOtherUser} - searchSpaceId={Number(search_space_id)} - onChatModelSelected={handleChatModelSelected} - /> + <ComposerAction isBlockedByOtherUser={isBlockedByOtherUser} /> <ConnectorIndicator showTrigger={false} /> </div> <ConnectToolsBanner @@ -960,15 +946,9 @@ const Composer: FC = () => { interface ComposerActionProps { isBlockedByOtherUser?: boolean; - searchSpaceId: number; - onChatModelSelected?: () => void; } -const ComposerAction: FC<ComposerActionProps> = ({ - isBlockedByOtherUser = false, - searchSpaceId, - onChatModelSelected, -}) => { +const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false }) => { const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const [toolsPopoverOpen, setToolsPopoverOpen] = useState(false); @@ -996,9 +976,9 @@ const ComposerAction: FC<ComposerActionProps> = ({ if (url) setPendingScreenImages((prev) => [...prev, url]); }, [electronAPI, setPendingScreenImages]); - const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom); - const { data: modelConnections } = useAtomValue(modelConnectionsAtom); - const { data: modelRoles } = useAtomValue(modelRolesAtom); + const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences } = useAtomValue(llmPreferencesAtom); const { data: agentTools } = useAtomValue(agentToolsAtom); const disabledTools = useAtomValue(disabledToolsAtom); @@ -1085,18 +1065,15 @@ const ComposerAction: FC<ComposerActionProps> = ({ }, [hydrateDisabled]); const hasModelConfigured = useMemo(() => { - const chatModelId = modelRoles?.chat_model_id ?? 0; - if (chatModelId === 0) { - return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => - connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) - ); + if (!preferences) return false; + const agentLlmId = preferences.agent_llm_id; + if (agentLlmId === null || agentLlmId === undefined) return false; + + if (agentLlmId <= 0) { + return globalConfigs?.some((c) => c.id === agentLlmId) ?? false; } - return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => - connection.models.some( - (model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat) - ) - ); - }, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]); + return userConfigs?.some((c) => c.id === agentLlmId) ?? false; + }, [preferences, globalConfigs, userConfigs]); const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser; @@ -1577,12 +1554,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ <span>Select a model</span> </div> )} - <div className="ml-auto flex min-w-0 shrink-0 items-center gap-2"> - <ChatHeader - searchSpaceId={searchSpaceId} - className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3" - onChatModelSelected={onChatModelSelected} - /> + <div className="flex items-center gap-2"> <AuiIf condition={({ thread }) => !thread.isRunning}> <ComposerPrimitive.Send asChild disabled={isSendDisabled}> <TooltipIconButton @@ -1590,7 +1562,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser ? "Wait for AI to finish responding" : !hasModelConfigured - ? "Please select a model to start chatting" + ? "Please select a model from the header to start chatting" : isComposerEmpty ? "Enter a message or add a screenshot to send" : "Send message" @@ -1600,7 +1572,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ variant="default" size="icon" className={cn( - "aui-composer-send size-9 shrink-0 rounded-full", + "aui-composer-send size-9 rounded-full", isSendDisabled && "cursor-not-allowed opacity-50" )} aria-label="Send message" @@ -1617,7 +1589,7 @@ const ComposerAction: FC<ComposerActionProps> = ({ type="button" variant="default" size="icon" - className="aui-composer-cancel size-9 shrink-0 rounded-full" + className="aui-composer-cancel size-9 rounded-full" aria-label="Stop generating" > <SquareIcon className="aui-composer-cancel-icon size-3.5 fill-current" /> diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index 8db8c2b50..dd80bcac3 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -9,18 +9,6 @@ import { useSyncExternalStore, } from "react"; -export interface TokenUsageModelBreakdown { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - cost_micros?: number; - model?: string | null; - model_ref?: string | null; - model_id?: string | null; - display_name?: string | null; - provider?: string | null; -} - export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; @@ -32,8 +20,24 @@ export interface TokenUsageData { * before the migration won't have it. */ cost_micros?: number; - usage?: Record<string, TokenUsageModelBreakdown>; - model_breakdown?: Record<string, TokenUsageModelBreakdown>; + usage?: Record< + string, + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } + >; + model_breakdown?: Record< + string, + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } + >; } type Listener = () => void; diff --git a/surfsense_web/components/auth/sign-in-button.tsx b/surfsense_web/components/auth/sign-in-button.tsx index d0a563a54..7f5a77f36 100644 --- a/surfsense_web/components/auth/sign-in-button.tsx +++ b/surfsense_web/components/auth/sign-in-button.tsx @@ -3,7 +3,7 @@ import Link from "next/link"; import { useState } from "react"; import { Button } from "@/components/ui/button"; -import { buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; @@ -46,54 +46,54 @@ interface SignInButtonProps { } export const SignInButton = ({ variant = "desktop" }: SignInButtonProps) => { + const isGoogleAuth = AUTH_TYPE === "GOOGLE"; const [isRedirecting, setIsRedirecting] = useState(false); const handleGoogleLogin = () => { if (isRedirecting) return; setIsRedirecting(true); trackLoginAttempt("google"); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; - const getGoogleClassName = () => { + const getClassName = () => { if (variant === "desktop") { - return "hidden rounded-full border border-white bg-white px-5 py-2 text-sm font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] md:flex dark:border-white"; + return isGoogleAuth + ? "hidden rounded-full border border-white bg-white px-5 py-2 text-sm font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] md:flex dark:border-white" + : "hidden rounded-full bg-black px-8 py-2 text-sm font-bold text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] md:block dark:bg-white dark:text-black"; } if (variant === "compact") { - return "rounded-full border border-white bg-white px-4 py-1.5 text-sm font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-white"; + return isGoogleAuth + ? "rounded-full border border-white bg-white px-4 py-1.5 text-sm font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-white" + : "rounded-full bg-black px-6 py-1.5 text-sm font-bold text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] dark:bg-white dark:text-black"; } // mobile - return "w-full rounded-lg border border-white bg-white px-8 py-2.5 font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-white touch-manipulation"; + return isGoogleAuth + ? "w-full rounded-lg border border-white bg-white px-8 py-2.5 font-medium text-[#1f1f1f] shadow-sm hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-white touch-manipulation" + : "w-full rounded-lg bg-black px-8 py-2 font-medium text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] dark:bg-white dark:text-black text-center touch-manipulation"; }; - const getLocalClassName = () => { - if (variant === "desktop") { - return "hidden rounded-full bg-black px-8 py-2 text-sm font-bold text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] md:block dark:bg-white dark:text-black"; - } - if (variant === "compact") { - return "rounded-full bg-black px-6 py-1.5 text-sm font-bold text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] dark:bg-white dark:text-black"; - } - return "w-full rounded-lg bg-black px-8 py-2 font-medium text-white shadow-[0px_-2px_0px_0px_rgba(255,255,255,0.4)_inset] dark:bg-white dark:text-black text-center touch-manipulation"; - }; - - return ( - <> + if (isGoogleAuth) { + return ( <Button type="button" variant="ghost" onClick={handleGoogleLogin} disabled={isRedirecting} className={cn( - "runtime-auth-google flex items-center justify-center gap-2 transition-colors duration-200 disabled:cursor-not-allowed disabled:opacity-50", - getGoogleClassName() + "flex items-center justify-center gap-2 transition-colors duration-200 disabled:cursor-not-allowed disabled:opacity-50", + getClassName() )} > <GoogleLogo className="h-4 w-4" /> <span>Sign In</span> </Button> - <Link href="/login" className={cn("runtime-auth-local", getLocalClassName())}> - Sign In - </Link> - </> + ); + } + + return ( + <Link href="/login" className={getClassName()}> + Sign In + </Link> ); }; diff --git a/surfsense_web/components/documents/download-original-button.tsx b/surfsense_web/components/documents/download-original-button.tsx index e04ead89a..b79b289b4 100644 --- a/surfsense_web/components/documents/download-original-button.tsx +++ b/surfsense_web/components/documents/download-original-button.tsx @@ -7,7 +7,7 @@ import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; interface DownloadOriginalButtonProps { documentId: number; @@ -41,7 +41,7 @@ export function DownloadOriginalButton({ documentId }: DownloadOriginalButtonPro setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/documents/${documentId}/download-original`), + `${BACKEND_URL}/api/v1/documents/${documentId}/download-original`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 75283c81f..01983cbe1 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -17,7 +17,6 @@ import { toast } from "sonner"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { DownloadOriginalButton } from "@/components/documents/download-original-button"; import { VersionHistoryButton } from "@/components/documents/version-history"; -import { PlateErrorBoundary } from "@/components/editor/plate-error-boundary"; import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { fetchMemoryEditorDocument, @@ -35,15 +34,14 @@ import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; import { inferMonacoLanguageFromPath } from "@/lib/editor-language"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; const PlateEditor = dynamic( () => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })), { ssr: false, loading: () => <EditorPanelSkeleton /> } ); -const LARGE_DOCUMENT_THRESHOLD = 1 * 1024 * 1024; // 1MB, matches backend -const LARGE_DOCUMENT_LINE_THRESHOLD = 5000; +const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB interface EditorContent { document_id: number; @@ -51,11 +49,9 @@ interface EditorContent { document_type?: string; source_markdown: string; content_size_bytes?: number; - line_count?: number; chunk_count?: number; viewer_mode?: ViewerMode; editor_plate_max_bytes?: number; - editor_plate_max_lines?: number; } const EDITABLE_DOCUMENT_TYPES = new Set(["FILE", "NOTE"]); @@ -122,15 +118,6 @@ function getUtf8ByteSize(value: string): number { return new TextEncoder().encode(value).byteLength; } -function countLines(value: string): number { - if (!value) return 0; - let count = 1; - for (let i = 0; i < value.length; i++) { - if (value.charCodeAt(i) === 10) count++; - } - return count; -} - function formatBytes(bytes: number): string { if (bytes >= 1024 * 1024) { return `${(bytes / 1024 / 1024).toFixed(1)}MB`; @@ -197,17 +184,10 @@ export function EditorPanelContent({ ); const plateMaxBytes = editorDoc?.editor_plate_max_bytes ?? LARGE_DOCUMENT_THRESHOLD; - const plateMaxLines = editorDoc?.editor_plate_max_lines ?? LARGE_DOCUMENT_LINE_THRESHOLD; - const docSizeBytes = editorDoc?.content_size_bytes ?? 0; - const docLineCount = - editorDoc?.line_count ?? - (editorDoc?.source_markdown ? countLines(editorDoc.source_markdown) : 0); - const isLargeDocument = docSizeBytes > plateMaxBytes || docLineCount > plateMaxLines; + const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > plateMaxBytes; const viewerMode: ViewerMode = isMemoryMode ? "plate" - : editorDoc?.viewer_mode === "monaco" || isLargeDocument - ? "monaco" - : "plate"; + : (editorDoc?.viewer_mode ?? (isLargeDocument ? "monaco" : "plate")); useEffect(() => { const controller = new AbortController(); @@ -280,12 +260,10 @@ export function EditorPanelContent({ return; } - const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` - ), - { method: "GET" } + const url = new URL( + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); + const response = await authenticatedFetch(url.toString(), { method: "GET" }); if (controller.signal.aborted) return; @@ -424,7 +402,7 @@ export function EditorPanelContent({ return; } const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -443,8 +421,7 @@ export function EditorPanelContent({ setEditedMarkdown(null); if (!options?.silent) { const savedSizeBytes = getUtf8ByteSize(markdownRef.current); - const savedLineCount = countLines(markdownRef.current); - if (savedSizeBytes > plateMaxBytes || savedLineCount > plateMaxLines) { + if (savedSizeBytes > plateMaxBytes) { toast.success("Document saved. It will reopen in raw markdown mode."); } else { toast.success("Document saved! Reindexing in background..."); @@ -470,7 +447,6 @@ export function EditorPanelContent({ memoryLimits, memoryScope, plateMaxBytes, - plateMaxLines, resolveLocalVirtualPath, searchSpaceId, ] @@ -491,12 +467,8 @@ export function EditorPanelContent({ const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); const activeMarkdown = editedMarkdown ?? editorDoc?.source_markdown ?? ""; const activeMarkdownSizeBytes = useMemo(() => getUtf8ByteSize(activeMarkdown), [activeMarkdown]); - const activeMarkdownLineCount = useMemo(() => countLines(activeMarkdown), [activeMarkdown]); - const isNearPlateLimit = - activeMarkdownSizeBytes >= plateMaxBytes * 0.9 || - activeMarkdownLineCount >= plateMaxLines * 0.9; - const isOverPlateLimit = - activeMarkdownSizeBytes > plateMaxBytes || activeMarkdownLineCount > plateMaxLines; + const isNearPlateLimit = activeMarkdownSizeBytes >= plateMaxBytes * 0.9; + const isOverPlateLimit = activeMarkdownSizeBytes > plateMaxBytes; const showPlateSizeWarning = showEditingActions && !isMemoryMode && !isLocalFileMode && isNearPlateLimit; const memoryLimitState = isMemoryMode @@ -509,13 +481,6 @@ export function EditorPanelContent({ ? "text-orange-500" : "text-muted-foreground"; const saveDisabled = saving || !hasUnsavedChanges || (memoryLimitState?.isOverLimit ?? false); - const editorInstanceKey = `${ - isMemoryMode - ? `memory-${memoryScope ?? "user"}` - : isLocalFileMode - ? (localFilePath ?? "local-file") - : documentId - }-${isEditing ? "editing" : "viewing"}`; const handleCancelEditing = useCallback(() => { const savedContent = editorDoc?.source_markdown ?? ""; @@ -531,9 +496,7 @@ export function EditorPanelContent({ setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown` - ), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); @@ -562,7 +525,7 @@ export function EditorPanelContent({ <AlertDescription className="flex items-center justify-between gap-4"> <span> This document is too large for the editor ( - {formatBytes(editorDoc.content_size_bytes ?? 0)}, {docLineCount.toLocaleString()} lines,{" "} + {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} {editorDoc.chunk_count ?? 0} chunks). Showing raw markdown below. </span> <Button @@ -839,43 +802,36 @@ export function EditorPanelContent({ <FileText className="size-4" /> <AlertDescription> {isOverPlateLimit - ? `This document is ${formatBytes(activeMarkdownSizeBytes)} and ${activeMarkdownLineCount.toLocaleString()} lines, above the rich editor limit of ${formatBytes(plateMaxBytes)} or ${plateMaxLines.toLocaleString()} lines. You can save, but it will reopen in raw markdown mode.` - : `This document is approaching the rich editor limit (${formatBytes(activeMarkdownSizeBytes)} of ${formatBytes(plateMaxBytes)}, ${activeMarkdownLineCount.toLocaleString()} of ${plateMaxLines.toLocaleString()} lines).`} + ? `This document is ${formatBytes(activeMarkdownSizeBytes)}, above the rich editor limit of ${formatBytes(plateMaxBytes)}. You can save, but it will reopen in raw markdown mode.` + : `This document is approaching the rich editor limit (${formatBytes(activeMarkdownSizeBytes)} of ${formatBytes(plateMaxBytes)}).`} </AlertDescription> </Alert> )} <div className="flex-1 min-h-0 overflow-hidden"> - <PlateErrorBoundary - key={`plate-boundary-${editorInstanceKey}`} - fallback={ - <SourceCodeEditor - path={`${editorDoc.title || "document"}.md`} - language="markdown" - value={editorDoc.source_markdown} - readOnly - onChange={() => {}} - /> - } - > - <PlateEditor - key={editorInstanceKey} - preset="full" - markdown={editorDoc.source_markdown} - onMarkdownChange={handleMarkdownChange} - readOnly={!isEditing} - placeholder="Start writing..." - editorVariant="default" - allowModeToggle={false} - reserveToolbarSpace - defaultEditing={isEditing} - className="**:[[role=toolbar]]:bg-sidebar!" - // Render `[citation:N]` badges in view mode only. - // Edit mode keeps raw text so the user can edit/delete - // tokens directly. `local_file` never reaches this branch - // (handled by the source_code editor above). - enableCitations={!isEditing && !isLocalFileMode && !isMemoryMode} - /> - </PlateErrorBoundary> + <PlateEditor + key={`${ + isMemoryMode + ? `memory-${memoryScope ?? "user"}` + : isLocalFileMode + ? (localFilePath ?? "local-file") + : documentId + }-${isEditing ? "editing" : "viewing"}`} + preset="full" + markdown={editorDoc.source_markdown} + onMarkdownChange={handleMarkdownChange} + readOnly={!isEditing} + placeholder="Start writing..." + editorVariant="default" + allowModeToggle={false} + reserveToolbarSpace + defaultEditing={isEditing} + className="**:[[role=toolbar]]:bg-sidebar!" + // Render `[citation:N]` badges in view mode only. + // Edit mode keeps raw text so the user can edit/delete + // tokens directly. `local_file` never reaches this branch + // (handled by the source_code editor above). + enableCitations={!isEditing && !isLocalFileMode && !isMemoryMode} + /> </div> </div> ) : ( diff --git a/surfsense_web/components/editor-panel/memory.ts b/surfsense_web/components/editor-panel/memory.ts index 1beb977a6..aa5b1f68d 100644 --- a/surfsense_web/components/editor-panel/memory.ts +++ b/surfsense_web/components/editor-panel/memory.ts @@ -1,7 +1,6 @@ "use client"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; export type MemoryScope = "user" | "team"; @@ -30,6 +29,10 @@ function getMemoryPath(scope: MemoryScope, searchSpaceId?: number | null) { return `/api/v1/searchspaces/${searchSpaceId}/memory`; } +function getBackendUrl(path: string) { + return `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}${path}`; +} + export function getMemoryLimitState(length: number, limits?: MemoryLimits | null) { if (!limits) { return { @@ -62,7 +65,7 @@ export async function fetchMemoryEditorDocument({ title?: string | null; signal?: AbortSignal; }) { - const response = await authenticatedFetch(buildBackendUrl(getMemoryPath(scope, searchSpaceId)), { + const response = await authenticatedFetch(getBackendUrl(getMemoryPath(scope, searchSpaceId)), { method: "GET", signal, }); @@ -94,7 +97,7 @@ export async function saveMemoryMarkdown({ searchSpaceId?: number | null; markdown: string; }) { - const response = await authenticatedFetch(buildBackendUrl(getMemoryPath(scope, searchSpaceId)), { + const response = await authenticatedFetch(getBackendUrl(getMemoryPath(scope, searchSpaceId)), { method: "PUT", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ memory_md: markdown }), diff --git a/surfsense_web/components/editor/plate-error-boundary.tsx b/surfsense_web/components/editor/plate-error-boundary.tsx deleted file mode 100644 index c5c18f5e0..000000000 --- a/surfsense_web/components/editor/plate-error-boundary.tsx +++ /dev/null @@ -1,34 +0,0 @@ -"use client"; - -import { Component, type ReactNode } from "react"; - -interface PlateErrorBoundaryProps { - children: ReactNode; - fallback: ReactNode; -} - -interface PlateErrorBoundaryState { - hasError: boolean; -} - -export class PlateErrorBoundary extends Component< - PlateErrorBoundaryProps, - PlateErrorBoundaryState -> { - constructor(props: PlateErrorBoundaryProps) { - super(props); - this.state = { hasError: false }; - } - - static getDerivedStateFromError(): PlateErrorBoundaryState { - return { hasError: true }; - } - - render() { - if (this.state.hasError) { - return this.props.fallback; - } - - return this.props.children; - } -} diff --git a/surfsense_web/components/free-chat/anonymous-chat.tsx b/surfsense_web/components/free-chat/anonymous-chat.tsx index e3b8273bc..aff58f7bc 100644 --- a/surfsense_web/components/free-chat/anonymous-chat.tsx +++ b/surfsense_web/components/free-chat/anonymous-chat.tsx @@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button"; import type { AnonModel, AnonQuotaResponse } from "@/contracts/types/anonymous-chat.types"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { readSSEStream } from "@/lib/chat/streaming-state"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; import { QuotaBar } from "./quota-bar"; @@ -81,7 +81,7 @@ export function AnonymousChat({ model }: AnonymousChatProps) { content: m.content, })); - const response = await fetch(buildBackendUrl("/api/v1/public/anon-chat/stream"), { + const response = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/stream`, { method: "POST", headers: { "Content-Type": "application/json" }, credentials: "include", @@ -188,6 +188,9 @@ export function AnonymousChat({ model }: AnonymousChatProps) { </div> </div> <h2 className="text-xl font-semibold mb-2">{model.name}</h2> + {model.description && ( + <p className="text-sm text-muted-foreground max-w-md">{model.description}</p> + )} <p className="text-xs text-muted-foreground mt-4"> Free to use · No login required · Start typing below </p> diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index 966aaee60..b28b1e0a1 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -33,8 +33,9 @@ import { updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; +import { FreeModelSelector } from "./free-model-selector"; import { FreeThread } from "./free-thread"; import { RemoveAdsBanner } from "./remove-ads-banner"; @@ -62,21 +63,6 @@ function normalizeFreeChatErrorMessage(error: unknown): string { if (code === "THREAD_BUSY") { return "A previous response is still stopping. Please try again in a moment."; } - if (code === "MODEL_AUTH_FAILED") { - return "This model’s API key is invalid or expired. Switch models, or update the API key."; - } - if (code === "MODEL_NOT_FOUND") { - return "This model is unavailable or no longer exists. Please switch models."; - } - if (code === "MODEL_CONTEXT_LIMIT") { - return "This request is too large for the selected model. Reduce the input or switch models."; - } - if (code === "MODEL_PROVIDER_UNAVAILABLE") { - return "The selected model provider is temporarily unavailable. Please try again or switch models."; - } - if (code === "RATE_LIMITED") { - return "This model is temporarily rate-limited. Please try again in a few seconds or switch models."; - } return error.message || "An unexpected error occurred"; } @@ -168,7 +154,7 @@ export function FreeChatPage() { assistantMsgId: string, signal: AbortSignal, turnstileToken: string | null - ): Promise<"captcha" | undefined> => { + ): Promise<"captcha" | void> => { const reqBody: Record<string, unknown> = { model_slug: modelSlug, messages: messageHistory, @@ -176,7 +162,7 @@ export function FreeChatPage() { if (!webSearchEnabled) reqBody.disabled_tools = ["web_search"]; if (turnstileToken) reqBody.turnstile_token = turnstileToken; - const response = await fetch(buildBackendUrl("/api/v1/public/anon-chat/stream"), { + const response = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/stream`, { method: "POST", headers: { "Content-Type": "application/json" }, credentials: "include", @@ -498,6 +484,10 @@ export function FreeChatPage() { <TimelineDataUI /> <StepSeparatorDataUI /> <div className="flex h-full flex-col overflow-hidden"> + <div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4"> + <FreeModelSelector /> + </div> + <RemoveAdsBanner /> {captchaRequired && TURNSTILE_SITE_KEY && ( diff --git a/surfsense_web/components/free-chat/free-composer.tsx b/surfsense_web/components/free-chat/free-composer.tsx index 162b906ad..46d9e0259 100644 --- a/surfsense_web/components/free-chat/free-composer.tsx +++ b/surfsense_web/components/free-chat/free-composer.tsx @@ -13,7 +13,6 @@ import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { useLoginGate } from "@/contexts/login-gate"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { cn } from "@/lib/utils"; -import { FreeModelSelector } from "./free-model-selector"; const ANON_ALLOWED_EXTENSIONS = new Set([ ".md", @@ -228,8 +227,7 @@ export const FreeComposer: FC = () => { </Tooltip> </div> - <div className="flex min-w-0 items-center gap-1"> - <FreeModelSelector className="h-8 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3" /> + <div className="flex items-center gap-1"> {!isRunning ? ( <ComposerPrimitive.Send asChild> <TooltipIconButton tooltip="Send" variant="default" className="size-8 rounded-full"> diff --git a/surfsense_web/components/free-chat/free-model-selector.tsx b/surfsense_web/components/free-chat/free-model-selector.tsx index d04bca8a2..9bf4ecee5 100644 --- a/surfsense_web/components/free-chat/free-model-selector.tsx +++ b/surfsense_web/components/free-chat/free-model-selector.tsx @@ -1,6 +1,6 @@ "use client"; -import { Check, ChevronDown, Cpu } from "lucide-react"; +import { Bot, Check, ChevronDown } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useState } from "react"; import { Badge } from "@/components/ui/badge"; @@ -82,7 +82,7 @@ export function FreeModelSelector({ className }: { className?: string }) { </> ) : ( <> - <Cpu className="size-4 text-muted-foreground" /> + <Bot className="size-4 text-muted-foreground" /> <span className="text-muted-foreground">Select Model</span> </> )} diff --git a/surfsense_web/components/homepage/global-announcement.tsx b/surfsense_web/components/homepage/global-announcement.tsx deleted file mode 100644 index 212be42c7..000000000 --- a/surfsense_web/components/homepage/global-announcement.tsx +++ /dev/null @@ -1,27 +0,0 @@ -import { IconInfoCircle } from "@tabler/icons-react"; -import { GLOBAL_ANNOUNCEMENT_ENABLED, GLOBAL_ANNOUNCEMENT_MESSAGE } from "@/lib/env-config"; - -/** - * Small, site-wide banner for planned downtime / maintenance notices. - * - * Controlled entirely through build-time env vars so it can be toggled from - * Vercel without a code change: - * - NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED ("true" to show) - * - NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE (the copy to display) - */ -export function GlobalAnnouncement() { - const message = GLOBAL_ANNOUNCEMENT_MESSAGE.trim(); - - if (!GLOBAL_ANNOUNCEMENT_ENABLED || !message) { - return null; - } - - return ( - <div className="fixed bottom-0 left-0 right-0 z-60 w-full bg-amber-500/15 text-amber-900 backdrop-blur-md dark:bg-amber-400/10 dark:text-amber-200 border-t border-amber-500/30"> - <div className="mx-auto flex max-w-7xl items-center justify-center gap-2 px-4 py-2 text-center text-sm font-medium"> - <IconInfoCircle className="h-4 w-4 shrink-0" /> - <span>{message}</span> - </div> - </div> - ); -} diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx index c9430f098..09cf316d8 100644 --- a/surfsense_web/components/homepage/hero-section.tsx +++ b/surfsense_web/components/homepage/hero-section.tsx @@ -37,7 +37,7 @@ import { getAssetLabel, usePrimaryDownload, } from "@/lib/desktop-download-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; @@ -314,35 +314,39 @@ export function HeroSection() { } function GetStartedButton() { + const isGoogleAuth = AUTH_TYPE === "GOOGLE"; const [isRedirecting, setIsRedirecting] = useState(false); const handleGoogleLogin = () => { if (isRedirecting) return; setIsRedirecting(true); trackLoginAttempt("google"); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; - return ( - <> + if (isGoogleAuth) { + return ( <Button type="button" variant="ghost" onClick={handleGoogleLogin} disabled={isRedirecting} - className="runtime-auth-google h-14 w-full cursor-pointer gap-3 rounded-lg border border-white bg-white text-center text-base font-medium text-[#1f1f1f] shadow-sm transition duration-150 hover:bg-zinc-100 hover:text-[#1f1f1f] sm:w-56 dark:border-white" + className="h-14 w-full cursor-pointer gap-3 rounded-lg border border-white bg-white text-center text-base font-medium text-[#1f1f1f] shadow-sm transition duration-150 hover:bg-zinc-100 hover:text-[#1f1f1f] sm:w-56 dark:border-white" > <GoogleLogo className="h-5 w-5" /> <span>Continue with Google</span> </Button> - <Button - asChild - variant="ghost" - className="runtime-auth-local h-14 w-full rounded-lg bg-black text-center text-base font-medium text-white shadow-sm ring-1 shadow-black/10 ring-black/10 transition duration-150 active:scale-98 hover:bg-black sm:w-52 dark:bg-white dark:text-black dark:hover:bg-white" - > - <Link href="/login">Get Started</Link> - </Button> - </> + ); + } + + return ( + <Button + asChild + variant="ghost" + className="h-14 w-full rounded-lg bg-black text-center text-base font-medium text-white shadow-sm ring-1 shadow-black/10 ring-black/10 transition duration-150 active:scale-98 hover:bg-black sm:w-52 dark:bg-white dark:text-black dark:hover:bg-white" + > + <Link href="/login">Get Started</Link> + </Button> ); } diff --git a/surfsense_web/components/icons/providers/azure.svg b/surfsense_web/components/icons/providers/azure.svg deleted file mode 100644 index ba80f55ca..000000000 --- a/surfsense_web/components/icons/providers/azure.svg +++ /dev/null @@ -1 +0,0 @@ -<svg viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M7.242 1.613A1.11 1.11 0 018.295.857h6.977L8.03 22.316a1.11 1.11 0 01-1.052.755h-5.43a1.11 1.11 0 01-1.053-1.466L7.242 1.613z" fill="url(#azure-gradient-0)"></path><path d="M18.397 15.296H7.4a.51.51 0 00-.347.882l7.066 6.595c.206.192.477.298.758.298h6.226l-2.706-7.775z" fill="#0078D4"></path><path d="M15.272.857H7.497L0 23.071h7.775l1.596-4.73 5.068 4.73h6.665l-2.707-7.775h-7.998L15.272.857z" fill="url(#azure-gradient-1)"></path><path d="M17.193 1.613a1.11 1.11 0 00-1.052-.756h-7.81.035c.477 0 .9.304 1.052.756l6.748 19.992a1.11 1.11 0 01-1.052 1.466h-.12 7.895a1.11 1.11 0 001.052-1.466L17.193 1.613z" fill="url(#azure-gradient-2)"></path><defs><linearGradient gradientUnits="userSpaceOnUse" id="azure-gradient-0" x1="8.247" x2="1.002" y1="1.626" y2="23.03"><stop stop-color="#114A8B"></stop><stop offset="1" stop-color="#0669BC"></stop></linearGradient><linearGradient gradientUnits="userSpaceOnUse" id="azure-gradient-1" x1="14.042" x2="12.324" y1="15.302" y2="15.888"><stop stop-opacity=".3"></stop><stop offset=".071" stop-opacity=".2"></stop><stop offset=".321" stop-opacity=".1"></stop><stop offset=".623" stop-opacity=".05"></stop><stop offset="1" stop-opacity="0"></stop></linearGradient><linearGradient gradientUnits="userSpaceOnUse" id="azure-gradient-2" x1="12.841" x2="20.793" y1="1.626" y2="22.814"><stop stop-color="#3CCBF4"></stop><stop offset="1" stop-color="#2892DF"></stop></linearGradient></defs></svg> \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/bedrock.svg b/surfsense_web/components/icons/providers/bedrock.svg index cde500c0d..195aa6594 100644 --- a/surfsense_web/components/icons/providers/bedrock.svg +++ b/surfsense_web/components/icons/providers/bedrock.svg @@ -1 +1 @@ -<svg viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><defs><linearGradient id="bedrock-gradient" x1="80%" x2="20%" y1="20%" y2="80%"><stop offset="0%" stop-color="#6350FB"></stop><stop offset="50%" stop-color="#3D8FFF"></stop><stop offset="100%" stop-color="#9AD8F8"></stop></linearGradient></defs><path d="M13.05 15.513h3.08c.214 0 .389.177.389.394v1.82a1.704 1.704 0 011.296 1.661c0 .943-.755 1.708-1.685 1.708-.931 0-1.686-.765-1.686-1.708 0-.807.554-1.484 1.297-1.662v-1.425h-2.69v4.663a.395.395 0 01-.188.338l-2.69 1.641a.385.385 0 01-.405-.002l-4.926-3.086a.395.395 0 01-.185-.336V16.3L2.196 14.87A.395.395 0 012 14.555L2 14.528V9.406c0-.14.073-.27.192-.34l2.465-1.462V4.448c0-.129.062-.249.165-.322l.021-.014L9.77 1.058a.385.385 0 01.407 0l2.69 1.675a.395.395 0 01.185.336V7.6h3.856V5.683a1.704 1.704 0 01-1.296-1.662c0-.943.755-1.708 1.685-1.708.931 0 1.685.765 1.685 1.708 0 .807-.553 1.484-1.296 1.662v2.311a.391.391 0 01-.389.394h-4.245v1.806h6.624a1.69 1.69 0 011.64-1.313c.93 0 1.685.764 1.685 1.707 0 .943-.754 1.708-1.685 1.708a1.69 1.69 0 01-1.64-1.314H13.05v1.937h4.953l.915 1.18a1.66 1.66 0 01.84-.227c.931 0 1.685.764 1.685 1.707 0 .943-.754 1.708-1.685 1.708-.93 0-1.685-.765-1.685-1.708 0-.346.102-.668.276-.937l-.724-.935H13.05v1.806zM9.973 1.856L7.93 3.122V6.09h-.778V3.604L5.435 4.669v2.945l2.11 1.36L9.712 7.61V5.334h.778V7.83c0 .136-.07.263-.184.335L7.963 9.638v2.081l1.422 1.009-.446.646-1.406-.998-1.53 1.005-.423-.66 1.605-1.055v-1.99L5.038 8.29l-2.26 1.34v1.676l1.972-1.189.398.677-2.37 1.429V14.3l2.166 1.258 2.27-1.368.397.677-2.176 1.311V19.3l1.876 1.175 2.365-1.426.398.678-2.017 1.216 1.918 1.201 2.298-1.403v-5.78l-4.758 2.893-.4-.675 5.158-3.136V3.289L9.972 1.856zM16.13 18.47a.913.913 0 00-.908.92c0 .507.406.918.908.918a.913.913 0 00.907-.919.913.913 0 00-.907-.92zm3.63-3.81a.913.913 0 00-.908.92c0 .508.406.92.907.92a.913.913 0 00.908-.92.913.913 0 00-.908-.92zm1.555-4.99a.913.913 0 00-.908.92c0 .507.407.918.908.918a.913.913 0 00.907-.919.913.913 0 00-.907-.92zM17.296 3.1a.913.913 0 00-.907.92c0 .508.406.92.907.92a.913.913 0 00.908-.92.913.913 0 00-.908-.92z" fill="url(#bedrock-gradient)" fill-rule="nonzero"></path></svg> \ No newline at end of file +<svg fill="currentColor" fill-rule="evenodd" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M13.05 15.513h3.08c.214 0 .389.177.389.394v1.82a1.704 1.704 0 011.296 1.661c0 .943-.755 1.708-1.685 1.708-.931 0-1.686-.765-1.686-1.708 0-.807.554-1.484 1.297-1.662v-1.425h-2.69v4.663a.395.395 0 01-.188.338l-2.69 1.641a.385.385 0 01-.405-.002l-4.926-3.086a.395.395 0 01-.185-.336V16.3L2.196 14.87A.395.395 0 012 14.555L2 14.528V9.406c0-.14.073-.27.192-.34l2.465-1.462V4.448c0-.129.062-.249.165-.322l.021-.014L9.77 1.058a.385.385 0 01.407 0l2.69 1.675a.395.395 0 01.185.336V7.6h3.856V5.683a1.704 1.704 0 01-1.296-1.662c0-.943.755-1.708 1.685-1.708.931 0 1.685.765 1.685 1.708 0 .807-.553 1.484-1.296 1.662v2.311a.391.391 0 01-.389.394h-4.245v1.806h6.624a1.69 1.69 0 011.64-1.313c.93 0 1.685.764 1.685 1.707 0 .943-.754 1.708-1.685 1.708a1.69 1.69 0 01-1.64-1.314H13.05v1.937h4.953l.915 1.18a1.66 1.66 0 01.84-.227c.931 0 1.685.764 1.685 1.707 0 .943-.754 1.708-1.685 1.708-.93 0-1.685-.765-1.685-1.708 0-.346.102-.668.276-.937l-.724-.935H13.05v1.806zM9.973 1.856L7.93 3.122V6.09h-.778V3.604L5.435 4.669v2.945l2.11 1.36L9.712 7.61V5.334h.778V7.83c0 .136-.07.263-.184.335L7.963 9.638v2.081l1.422 1.009-.446.646-1.406-.998-1.53 1.005-.423-.66 1.605-1.055v-1.99L5.038 8.29l-2.26 1.34v1.676l1.972-1.189.398.677-2.37 1.429V14.3l2.166 1.258 2.27-1.368.397.677-2.176 1.311V19.3l1.876 1.175 2.365-1.426.398.678-2.017 1.216 1.918 1.201 2.298-1.403v-5.78l-4.758 2.893-.4-.675 5.158-3.136V3.289L9.972 1.856zM16.13 18.47a.913.913 0 00-.908.92c0 .507.406.918.908.918a.913.913 0 00.907-.919.913.913 0 00-.907-.92zm3.63-3.81a.913.913 0 00-.908.92c0 .508.406.92.907.92a.913.913 0 00.908-.92.913.913 0 00-.908-.92zm1.555-4.99a.913.913 0 00-.908.92c0 .507.407.918.908.918a.913.913 0 00.907-.919.913.913 0 00-.907-.92zM17.296 3.1a.913.913 0 00-.907.92c0 .508.406.92.907.92a.913.913 0 00.908-.92.913.913 0 00-.908-.92z"></path></svg> \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/claude.svg b/surfsense_web/components/icons/providers/claude.svg deleted file mode 100644 index 8d732d5b0..000000000 --- a/surfsense_web/components/icons/providers/claude.svg +++ /dev/null @@ -1 +0,0 @@ -<svg viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M4.709 15.955l4.72-2.647.08-.23-.08-.128H9.2l-.79-.048-2.698-.073-2.339-.097-2.266-.122-.571-.121L0 11.784l.055-.352.48-.321.686.06 1.52.103 2.278.158 1.652.097 2.449.255h.389l.055-.157-.134-.098-.103-.097-2.358-1.596-2.552-1.688-1.336-.972-.724-.491-.364-.462-.158-1.008.656-.722.881.06.225.061.893.686 1.908 1.476 2.491 1.833.365.304.145-.103.019-.073-.164-.274-1.355-2.446-1.446-2.49-.644-1.032-.17-.619a2.97 2.97 0 01-.104-.729L6.283.134 6.696 0l.996.134.42.364.62 1.414 1.002 2.229 1.555 3.03.456.898.243.832.091.255h.158V9.01l.128-1.706.237-2.095.23-2.695.08-.76.376-.91.747-.492.584.28.48.685-.067.444-.286 1.851-.559 2.903-.364 1.942h.212l.243-.242.985-1.306 1.652-2.064.73-.82.85-.904.547-.431h1.033l.76 1.129-.34 1.166-1.064 1.347-.881 1.142-1.264 1.7-.79 1.36.073.11.188-.02 2.856-.606 1.543-.28 1.841-.315.833.388.091.395-.328.807-1.969.486-2.309.462-3.439.813-.042.03.049.061 1.549.146.662.036h1.622l3.02.225.79.522.474.638-.079.485-1.215.62-1.64-.389-3.829-.91-1.312-.329h-.182v.11l1.093 1.068 2.006 1.81 2.509 2.33.127.578-.322.455-.34-.049-2.205-1.657-.851-.747-1.926-1.62h-.128v.17l.444.649 2.345 3.521.122 1.08-.17.353-.608.213-.668-.122-1.374-1.925-1.415-2.167-1.143-1.943-.14.08-.674 7.254-.316.37-.729.28-.607-.461-.322-.747.322-1.476.389-1.924.315-1.53.286-1.9.17-.632-.012-.042-.14.018-1.434 1.967-2.18 2.945-1.726 1.845-.414.164-.717-.37.067-.662.401-.589 2.388-3.036 1.44-1.882.93-1.086-.006-.158h-.055L4.132 18.56l-1.13.146-.487-.456.061-.746.231-.243 1.908-1.312-.006.006z" fill="#D97757" fill-rule="nonzero"></path></svg> \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/index.ts b/surfsense_web/components/icons/providers/index.ts index 5c8276e62..aefa2a053 100644 --- a/surfsense_web/components/icons/providers/index.ts +++ b/surfsense_web/components/icons/providers/index.ts @@ -1,10 +1,8 @@ export { default as Ai21Icon } from "./ai21.svg"; export { default as AnthropicIcon } from "./anthropic.svg"; export { default as AnyscaleIcon } from "./anyscale.svg"; -export { default as AzureIcon } from "./azure.svg"; export { default as BedrockIcon } from "./bedrock.svg"; export { default as CerebrasIcon } from "./cerebras.svg"; -export { default as ClaudeIcon } from "./claude.svg"; export { default as CohereIcon } from "./cohere.svg"; export { default as CometApiIcon } from "./cometapi.svg"; export { default as DatabricksIcon } from "./dbrx.svg"; @@ -15,7 +13,6 @@ export { default as GeminiIcon } from "./gemini.svg"; export { default as GitHubModelsIcon } from "./github.svg"; export { default as GroqIcon } from "./groq.svg"; export { default as HuggingFaceIcon } from "./huggingface.svg"; -export { default as LmStudioIcon } from "./lm-studio.svg"; export { default as MiniMaxIcon } from "./minimax.svg"; export { default as MistralIcon } from "./mistral.svg"; export { default as MoonshotIcon } from "./moonshot.svg"; diff --git a/surfsense_web/components/icons/providers/lm-studio.svg b/surfsense_web/components/icons/providers/lm-studio.svg deleted file mode 100644 index b6ae7db3e..000000000 --- a/surfsense_web/components/icons/providers/lm-studio.svg +++ /dev/null @@ -1,21 +0,0 @@ -<svg viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg"> -<path d="M0 179.2C0 116.474 0 85.1112 12.2073 61.1531C22.9451 40.0789 40.0789 22.9451 61.1531 12.2073C85.1112 0 116.474 0 179.2 0H332.8C395.526 0 426.889 0 450.847 12.2073C471.921 22.9451 489.055 40.0789 499.793 61.1531C512 85.1112 512 116.474 512 179.2V332.8C512 395.526 512 426.889 499.793 450.847C489.055 471.921 471.921 489.055 450.847 499.793C426.889 512 395.526 512 332.8 512H179.2C116.474 512 85.1112 512 61.1531 499.793C40.0789 489.055 22.9451 471.921 12.2073 450.847C0 426.889 0 395.526 0 332.8V179.2Z" fill="url(#lm-studio-gradient)"/> -<rect opacity="0.25" x="128" y="84" width="224" height="44" rx="22" fill="white"/> -<rect x="64" y="84" width="224" height="44" rx="22" fill="white"/> -<rect opacity="0.25" x="224" y="144" width="224" height="44" rx="22" fill="white"/> -<rect x="160" y="144" width="224" height="44" rx="22" fill="white"/> -<rect opacity="0.25" x="168" y="204" width="224" height="44" rx="22" fill="white"/> -<rect x="104" y="204" width="224" height="44" rx="22" fill="white"/> -<rect opacity="0.25" x="112" y="264" width="224" height="44" rx="22" fill="white"/> -<rect x="48" y="264" width="224" height="44" rx="22" fill="white"/> -<rect opacity="0.25" x="176" y="324" width="224" height="44" rx="22" fill="white"/> -<rect x="112" y="324" width="224" height="44" rx="22" fill="white"/> -<rect opacity="0.25" x="304" y="384" width="152" height="44" rx="22" fill="white"/> -<rect x="240" y="384" width="152" height="44" rx="22" fill="white"/> -<defs> -<linearGradient id="lm-studio-gradient" x1="-219.792" y1="229.426" x2="239.06" y2="702.601" gradientUnits="userSpaceOnUse"> -<stop stop-color="#6E7EF3"/> -<stop offset="1" stop-color="#4F13BE"/> -</linearGradient> -</defs> -</svg> diff --git a/surfsense_web/components/icons/providers/vertexai.svg b/surfsense_web/components/icons/providers/vertexai.svg index e46a3ca0f..45adce83b 100644 --- a/surfsense_web/components/icons/providers/vertexai.svg +++ b/surfsense_web/components/icons/providers/vertexai.svg @@ -1 +1 @@ -<svg viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M11.995 20.216a1.892 1.892 0 100 3.785 1.892 1.892 0 000-3.785zm0 2.806a.927.927 0 11.927-.914.914.914 0 01-.927.914z" fill="#4285F4"></path><path clip-rule="evenodd" d="M21.687 14.144c.237.038.452.16.605.344a.978.978 0 01-.18 1.3l-8.24 6.082a1.892 1.892 0 00-1.147-1.508l8.28-6.08a.991.991 0 01.682-.138z" fill="#669DF6" fill-rule="evenodd"></path><path clip-rule="evenodd" d="M10.122 21.842l-8.217-6.066a.952.952 0 01-.206-1.287.978.978 0 011.287-.206l8.28 6.08a1.893 1.893 0 00-1.144 1.479z" fill="#AECBFA" fill-rule="evenodd"></path><path d="M4.273 4.475a.978.978 0 01-.965-.965V1.09a.978.978 0 111.943 0v2.42a.978.978 0 01-.978.965zM4.247 13.034a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 10.19a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 7.332a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#AECBFA"></path><path d="M19.718 7.307a.978.978 0 01-.965-.979v-2.42a.965.965 0 011.93 0v2.42a.964.964 0 01-.965.979zM19.743 13.047a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 10.151a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 2.068a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#4285F4"></path><path d="M11.995 15.917a.978.978 0 01-.965-.965v-2.459a.978.978 0 011.943 0v2.433a.976.976 0 01-.978.991zM11.995 18.762a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 10.64a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 7.783a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#669DF6"></path><path d="M15.856 10.177a.978.978 0 01-.965-.965v-2.42a.977.977 0 011.702-.763.979.979 0 01.241.763v2.42a.978.978 0 01-.978.965zM15.869 4.913a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 12.996a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#4285F4"></path><path d="M8.121 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 7.783a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 4.913a.978.978 0 100-1.957.978.978 0 000 1.957zM8.134 12.996a.978.978 0 01-.978-.94V9.611a.965.965 0 011.93 0v2.445a.966.966 0 01-.952.94z" fill="#AECBFA"></path></svg> \ No newline at end of file +<svg fill="currentColor" fill-rule="evenodd" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M11.995 20.216a1.892 1.892 0 100 3.785 1.892 1.892 0 000-3.785zm0 2.806a.927.927 0 11.927-.914.914.914 0 01-.927.914z"></path><path clip-rule="evenodd" d="M21.687 14.144c.237.038.452.16.605.344a.978.978 0 01-.18 1.3l-8.24 6.082a1.892 1.892 0 00-1.147-1.508l8.28-6.08a.991.991 0 01.682-.138z"></path><path clip-rule="evenodd" d="M10.122 21.842l-8.217-6.066a.952.952 0 01-.206-1.287.978.978 0 011.287-.206l8.28 6.08a1.893 1.893 0 00-1.144 1.479z"></path><path d="M4.273 4.475a.978.978 0 01-.965-.965V1.09a.978.978 0 111.943 0v2.42a.978.978 0 01-.978.965zM4.247 13.034a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 10.19a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 7.332a.978.978 0 100-1.956.978.978 0 000 1.956z"></path><path d="M19.718 7.307a.978.978 0 01-.965-.979v-2.42a.965.965 0 011.93 0v2.42a.964.964 0 01-.965.979zM19.743 13.047a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 10.151a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 2.068a.978.978 0 100-1.956.978.978 0 000 1.956z"></path><path d="M11.995 15.917a.978.978 0 01-.965-.965v-2.459a.978.978 0 011.943 0v2.433a.976.976 0 01-.978.991zM11.995 18.762a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 10.64a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 7.783a.978.978 0 100-1.956.978.978 0 000 1.956z"></path><path d="M15.856 10.177a.978.978 0 01-.965-.965v-2.42a.977.977 0 011.702-.763.979.979 0 01.241.763v2.42a.978.978 0 01-.978.965zM15.869 4.913a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 12.996a.978.978 0 100-1.956.978.978 0 000 1.956z"></path><path d="M8.121 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 7.783a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 4.913a.978.978 0 100-1.957.978.978 0 000 1.957zM8.134 12.996a.978.978 0 01-.978-.94V9.611a.965.965 0 011.93 0v2.445a.966.966 0 01-.952.94z"></path></svg> \ No newline at end of file diff --git a/surfsense_web/components/layout/index.ts b/surfsense_web/components/layout/index.ts index eb475e414..67f161d1a 100644 --- a/surfsense_web/components/layout/index.ts +++ b/surfsense_web/components/layout/index.ts @@ -12,7 +12,6 @@ export type { export { ChatListItem, CreateSearchSpaceDialog, - CreditBalanceDisplay, Header, IconRail, LayoutShell, @@ -20,6 +19,7 @@ export { MobileSidebarTrigger, NavIcon, NavSection, + PageUsageDisplay, SearchSpaceAvatar, Sidebar, SidebarCollapseButton, diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 429a1fde8..46f6ec8ae 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlarmClock, AlertTriangle, Inbox, LibraryBig } from "lucide-react"; +import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -186,40 +186,40 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid setStatusInboxItems(statusInbox.inboxItems); }, [statusInbox.inboxItems, setStatusInboxItems]); - // Track seen notification IDs to detect new insufficient_credits notifications - const seenCreditNotifications = useRef<Set<number>>(new Set()); + // Track seen notification IDs to detect new page_limit_exceeded notifications + const seenPageLimitNotifications = useRef<Set<number>>(new Set()); const isInitialLoad = useRef(true); - // Effect to show toast for new insufficient_credits notifications + // Effect to show toast for new page_limit_exceeded notifications useEffect(() => { if (statusInbox.loading) return; - const creditNotifications = statusInbox.inboxItems.filter( - (item) => item.type === "insufficient_credits" + const pageLimitNotifications = statusInbox.inboxItems.filter( + (item) => item.type === "page_limit_exceeded" ); if (isInitialLoad.current) { - for (const notification of creditNotifications) { - seenCreditNotifications.current.add(notification.id); + for (const notification of pageLimitNotifications) { + seenPageLimitNotifications.current.add(notification.id); } isInitialLoad.current = false; return; } - const newNotifications = creditNotifications.filter( - (notification) => !seenCreditNotifications.current.has(notification.id) + const newNotifications = pageLimitNotifications.filter( + (notification) => !seenPageLimitNotifications.current.has(notification.id) ); for (const notification of newNotifications) { - seenCreditNotifications.current.add(notification.id); + seenPageLimitNotifications.current.add(notification.id); toast.error(notification.title, { description: notification.message, duration: 8000, icon: <AlertTriangle className="h-5 w-5 text-amber-500" />, action: { - label: "Buy credits", - onClick: () => router.push(`/dashboard/${searchSpaceId}/buy-more`), + label: "Get More Pages", + onClick: () => router.push(`/dashboard/${searchSpaceId}/more-pages`), }, }); } @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: AlarmClock, + icon: Workflow, isActive: isAutomationsActive, }, isMobile @@ -696,7 +696,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid const isAutomationsPage = pathname?.includes("/automations") === true; const useWorkspacePanel = pathname?.endsWith("/buy-more") === true || - pathname?.endsWith("/earn-credits") === true || pathname?.endsWith("/more-pages") === true || isUserSettingsPage || isSearchSpaceSettingsPage || diff --git a/surfsense_web/components/layout/types/layout.types.ts b/surfsense_web/components/layout/types/layout.types.ts index 1dfb51ca8..1bb0a089e 100644 --- a/surfsense_web/components/layout/types/layout.types.ts +++ b/surfsense_web/components/layout/types/layout.types.ts @@ -74,6 +74,11 @@ export interface ChatsSectionProps { searchSpaceId?: string; } +export interface PageUsageDisplayProps { + pagesUsed: number; + pagesLimit: number; +} + export interface SidebarUserProfileProps { user: User; searchSpaceId?: string; diff --git a/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx b/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx index 009b2c120..6f385b465 100644 --- a/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx +++ b/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx @@ -67,7 +67,7 @@ export function CreateSearchSpaceDialog({ open, onOpenChange }: CreateSearchSpac trackSearchSpaceCreated(result.id, values.name); - router.push(`/dashboard/${result.id}/new-chat`); + router.push(`/dashboard/${result.id}/onboard`); } catch (error) { console.error("Failed to create search space:", error); setIsSubmitting(false); diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx index ea700391a..79839622d 100644 --- a/surfsense_web/components/layout/ui/header/Header.tsx +++ b/surfsense_web/components/layout/ui/header/Header.tsx @@ -6,6 +6,7 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeTabAtom } from "@/atoms/tabs/tabs.atom"; import { ActionLogButton } from "@/components/agent-action-log/action-log-button"; +import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatShareButton } from "@/components/new-chat/chat-share-button"; import type { ThreadRecord } from "@/lib/chat/thread-persistence"; @@ -65,8 +66,13 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { return ( <header className="sticky top-0 z-10 flex h-14 shrink-0 items-center gap-2 bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60 px-4"> - {/* Left side - Mobile menu trigger */} - <div className="flex flex-1 items-center gap-2 min-w-0">{mobileMenuTrigger}</div> + {/* Left side - Mobile menu trigger + Model selector */} + <div className="flex flex-1 items-center gap-2 min-w-0"> + {mobileMenuTrigger} + {isChatPage && !isDocumentTab && searchSpaceId && ( + <ChatHeader searchSpaceId={Number(searchSpaceId)} className="md:h-9 md:px-4 md:text-sm" /> + )} + </div> {/* Right side - Actions */} <div className="ml-auto flex items-center gap-2"> diff --git a/surfsense_web/components/layout/ui/index.ts b/surfsense_web/components/layout/ui/index.ts index 85a47bea1..00b862082 100644 --- a/surfsense_web/components/layout/ui/index.ts +++ b/surfsense_web/components/layout/ui/index.ts @@ -4,10 +4,10 @@ export { IconRail, NavIcon, SearchSpaceAvatar } from "./icon-rail"; export { LayoutShell } from "./shell"; export { ChatListItem, - CreditBalanceDisplay, MobileSidebar, MobileSidebarTrigger, NavSection, + PageUsageDisplay, Sidebar, SidebarCollapseButton, SidebarHeader, diff --git a/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx new file mode 100644 index 000000000..ad31d50bb --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx @@ -0,0 +1,15 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; +import { PageUsageDisplay } from "./PageUsageDisplay"; + +export function AuthenticatedPageUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />; +} diff --git a/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx deleted file mode 100644 index 1d45137fb..000000000 --- a/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx +++ /dev/null @@ -1,55 +0,0 @@ -"use client"; - -import { useQuery } from "@rocicorp/zero/react"; -import { useIsAnonymous } from "@/contexts/anonymous-mode"; -import { cn } from "@/lib/utils"; -import { queries } from "@/zero/queries"; - -// Show the low-balance warning state once the wallet drops below $0.50. -const LOW_BALANCE_WARNING_MICROS = 500_000; - -function formatUsd(micros: number): string { - // Clamp at $0.00 — the balance can dip slightly negative when the actual - // cost of a job exceeds the pre-charge estimate. - const dollars = Math.max(0, micros) / 1_000_000; - if (dollars >= 100) return `$${dollars.toFixed(0)}`; - if (dollars >= 1) return `$${dollars.toFixed(2)}`; - // Sub-dollar balances need extra precision so the user can still tell what - // is left ("$0.042 of credit") instead of rounding to "$0.00". - if (dollars > 0) return `$${dollars.toFixed(3)}`; - return "$0.00"; -} - -/** - * Unified credit-wallet balance shown in the sidebar. - * - * The single ``creditMicrosBalance`` replaces the former page-limit and - * premium-credit meters. Values come from Zero (live-replicated from Postgres) - * as integer micro-USD (1_000_000 == $1.00). A low-balance warning highlights - * the amount when it falls below $0.50 so the user knows to top up or enable - * auto-reload. - */ -export function CreditBalanceDisplay() { - const isAnonymous = useIsAnonymous(); - const [me] = useQuery(queries.user.me({})); - - if (isAnonymous || !me) return null; - - const balanceMicros = me.creditMicrosBalance ?? 0; - const isLow = balanceMicros < LOW_BALANCE_WARNING_MICROS; - - return ( - <div className="flex items-center justify-between text-xs"> - <span className="text-muted-foreground">Credits</span> - <span - className={cn( - "font-medium tabular-nums", - isLow ? "text-amber-600 dark:text-amber-500" : "text-foreground" - )} - title={isLow ? "Low balance — buy credits or enable auto-reload" : undefined} - > - {formatUsd(balanceMicros)} - </span> - </div> - ); -} diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 44cc56ab0..6c6668319 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -43,7 +43,6 @@ import type { FolderDisplay } from "@/components/documents/FolderNode"; import { FolderPickerDialog } from "@/components/documents/FolderPickerDialog"; import { FolderTreeView } from "@/components/documents/FolderTreeView"; import { VersionHistoryDialog } from "@/components/documents/version-history"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems"; import { DEFAULT_EXCLUDE_PATTERNS, @@ -79,7 +78,7 @@ import { foldersApiService } from "@/lib/apis/folders-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; @@ -227,7 +226,6 @@ function AuthenticatedDocumentsSidebarBase({ const isMobile = !useMediaQuery("(min-width: 640px)"); const platformElectronAPI = useElectronAPI(); const electronAPI = desktopFeaturesEnabled ? platformElectronAPI : null; - const { etlService } = useRuntimeConfig(); const searchSpaceId = Number(params.search_space_id); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const openEditorPanel = useSetAtom(openEditorPanelAtom); @@ -620,8 +618,7 @@ function AuthenticatedDocumentsSidebarBase({ folderName: matched.name, searchSpaceId, excludePatterns: matched.excludePatterns ?? DEFAULT_EXCLUDE_PATTERNS, - fileExtensions: - matched.fileExtensions ?? Array.from(getSupportedExtensionsSet(undefined, etlService)), + fileExtensions: matched.fileExtensions ?? Array.from(getSupportedExtensionsSet()), rootFolderId: folder.id, }); toast.success(`Re-scan complete: ${matched.name}`); @@ -629,7 +626,7 @@ function AuthenticatedDocumentsSidebarBase({ toast.error((err as Error)?.message || "Failed to re-scan folder"); } }, - [searchSpaceId, electronAPI, etlService] + [searchSpaceId, electronAPI] ); const handleStopWatching = useCallback( @@ -751,9 +748,7 @@ function AuthenticatedDocumentsSidebarBase({ .trim() .slice(0, 80) || "folder"; await doExport( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`, { - folder_id: ctx.folder.id, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`, `${safeName}.zip` ); toast.success(`Folder "${ctx.folder.name}" exported`); @@ -805,9 +800,7 @@ function AuthenticatedDocumentsSidebarBase({ .trim() .slice(0, 80) || "folder"; await doExport( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`, { - folder_id: folder.id, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${folder.id}`, `${safeName}.zip` ); toast.success(`Folder "${folder.name}" exported`); @@ -827,8 +820,8 @@ function AuthenticatedDocumentsSidebarBase({ try { const endpoint = doc.document_type === "USER_MEMORY" - ? buildBackendUrl("/api/v1/users/me/memory") - : buildBackendUrl(`/api/v1/searchspaces/${searchSpaceId}/memory`); + ? `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/memory` + : `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/searchspaces/${searchSpaceId}/memory`; const response = await authenticatedFetch(endpoint, { method: "GET" }); if (!response.ok) { const errorData = await response.json().catch(() => ({ detail: "Export failed" })); @@ -856,9 +849,7 @@ function AuthenticatedDocumentsSidebarBase({ try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export`, { - format, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${format}`, { method: "GET" } ); @@ -1037,8 +1028,8 @@ function AuthenticatedDocumentsSidebarBase({ } const endpoint = doc.document_type === "USER_MEMORY" - ? buildBackendUrl("/api/v1/users/me/memory/reset") - : buildBackendUrl(`/api/v1/searchspaces/${searchSpaceId}/memory/reset`); + ? `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/memory/reset` + : `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/searchspaces/${searchSpaceId}/memory/reset`; try { const response = await authenticatedFetch(endpoint, { method: "POST" }); if (!response.ok) { @@ -1158,7 +1149,6 @@ function AuthenticatedDocumentsSidebarBase({ const showCloudSkeleton = currentFilesystemTab === "cloud" && (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); - const connectorButtonLabel = connectorCount > 0 ? "Manage connectors" : "Connect your connectors"; const cloudContent = ( <> @@ -1171,7 +1161,9 @@ function AuthenticatedDocumentsSidebarBase({ className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground" > <Unplug className="size-4 shrink-0" /> - <span className="truncate">{connectorButtonLabel}</span> + <span className="truncate"> + {connectorCount > 0 ? "Manage connectors" : "Connect your connectors"} + </span> {connectorCount > 0 && ( <span className="shrink-0 rounded-full bg-muted-foreground/15 px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground"> {connectorCount} diff --git a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx index 3785dc649..f757db70e 100644 --- a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx @@ -48,8 +48,8 @@ import { isCommentReplyMetadata, isConnectorIndexingMetadata, isDocumentProcessingMetadata, - isInsufficientCreditsMetadata, isNewMentionMetadata, + isPageLimitExceededMetadata, } from "@/contracts/types/inbox.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import type { InboxItem } from "@/hooks/use-inbox"; @@ -291,7 +291,7 @@ export function InboxSidebarContent({ (item: InboxItem): boolean => { if (activeFilter === "unread") return !item.read; if (activeFilter === "errors") { - if (item.type === "insufficient_credits") return true; + if (item.type === "page_limit_exceeded") return true; const meta = item.metadata as Record<string, unknown> | undefined; return typeof meta?.status === "string" && meta.status === "failed"; } @@ -397,8 +397,8 @@ export function InboxSidebarContent({ router.push(url); } } - } else if (item.type === "insufficient_credits") { - if (isInsufficientCreditsMetadata(item.metadata)) { + } else if (item.type === "page_limit_exceeded") { + if (isPageLimitExceededMetadata(item.metadata)) { const actionUrl = item.metadata.action_url; if (actionUrl) { onOpenChange(false); @@ -470,7 +470,7 @@ export function InboxSidebarContent({ ); } - if (item.type === "insufficient_credits") { + if (item.type === "page_limit_exceeded") { return ( <div className="h-8 w-8 flex items-center justify-center rounded-full bg-amber-500/10"> <AlertTriangle className="h-4 w-4 text-amber-500" /> diff --git a/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx new file mode 100644 index 000000000..3d011b762 --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { Progress } from "@/components/ui/progress"; + +interface PageUsageDisplayProps { + pagesUsed: number; + pagesLimit: number; +} + +export function PageUsageDisplay({ pagesUsed, pagesLimit }: PageUsageDisplayProps) { + const usagePercentage = (pagesUsed / pagesLimit) * 100; + + return ( + <div className="space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {pagesUsed.toLocaleString()} / {pagesLimit.toLocaleString()} pages + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5" /> + </div> + ); +} diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx new file mode 100644 index 000000000..983672d0b --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx @@ -0,0 +1,49 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { Progress } from "@/components/ui/progress"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; + +/** + * Premium credit balance shown in the sidebar. + * + * Values come from Zero (live-replicated from Postgres) and are stored as + * integer micro-USD (1_000_000 == $1.00). We render in dollars because + * users top up at $1/pack and the credit gets debited at actual provider + * cost. + */ +export function PremiumTokenUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + const usagePercentage = Math.min( + (me.premiumCreditMicrosUsed / Math.max(me.premiumCreditMicrosLimit, 1)) * 100, + 100 + ); + + const formatUsd = (micros: number) => { + const dollars = micros / 1_000_000; + if (dollars >= 100) return `$${dollars.toFixed(0)}`; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + // Sub-dollar balances need extra precision so the bar still tells the + // user what's left ("$0.04 of credit") instead of rounding to "$0". + if (dollars > 0) return `$${dollars.toFixed(3)}`; + return "$0"; + }; + + return ( + <div className="space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of + credit + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5 [&>div]:bg-purple-500" /> + </div> + ); +} diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx index ee891d78b..6a4785d98 100644 --- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx @@ -1,6 +1,6 @@ "use client"; -import { CreditCard, SquarePen, Zap } from "lucide-react"; +import { CreditCard, Dot, SquarePen, Zap } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; import { useTranslations } from "next-intl"; @@ -13,9 +13,10 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { cn } from "@/lib/utils"; import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize"; import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types"; +import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay"; import { ChatListItem } from "./ChatListItem"; -import { CreditBalanceDisplay } from "./CreditBalanceDisplay"; import { NavSection } from "./NavSection"; +import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay"; import { SidebarButton } from "./SidebarButton"; import { SidebarCollapseButton } from "./SidebarCollapseButton"; import { SidebarHeader } from "./SidebarHeader"; @@ -403,16 +404,17 @@ function SidebarUsageFooter({ return ( <div className={containerClass}> - <CreditBalanceDisplay /> + <PremiumTokenUsageDisplay /> + <AuthenticatedPageUsageDisplay /> <div className="space-y-0.5"> <Link - href={`/dashboard/${searchSpaceId}/earn-credits`} + href={`/dashboard/${searchSpaceId}/more-pages`} onClick={onNavigate} className="group flex w-full items-center justify-between rounded-md px-1.5 py-1 transition-colors hover:bg-accent" > <span className="flex items-center gap-1.5 text-xs text-muted-foreground group-hover:text-accent-foreground"> <Zap className="h-3 w-3 shrink-0" /> - Earn credits + Get Free Pages </span> <Badge className="h-4 rounded px-1 text-[10px] font-semibold leading-none bg-emerald-600 text-white border-transparent hover:bg-emerald-600"> FREE @@ -425,7 +427,12 @@ function SidebarUsageFooter({ > <span className="flex items-center gap-1.5 text-xs text-muted-foreground group-hover:text-accent-foreground"> <CreditCard className="h-3 w-3 shrink-0" /> - Buy credits + Buy More + </span> + <span className="flex items-center text-[10px] font-medium text-muted-foreground"> + $1/1k + <Dot className="h-3 w-3" /> + $1/1M </span> </Link> </div> diff --git a/surfsense_web/components/layout/ui/sidebar/index.ts b/surfsense_web/components/layout/ui/sidebar/index.ts index fcfe2252d..e25149b06 100644 --- a/surfsense_web/components/layout/ui/sidebar/index.ts +++ b/surfsense_web/components/layout/ui/sidebar/index.ts @@ -1,10 +1,10 @@ export { AllChatsSidebar, AllChatsSidebarContent } from "./AllChatsSidebar"; export { ChatListItem } from "./ChatListItem"; -export { CreditBalanceDisplay } from "./CreditBalanceDisplay"; export { DocumentsSidebar } from "./DocumentsSidebar"; export { InboxSidebar, InboxSidebarContent } from "./InboxSidebar"; export { MobileSidebar, MobileSidebarTrigger } from "./MobileSidebar"; export { NavSection } from "./NavSection"; +export { PageUsageDisplay } from "./PageUsageDisplay"; export { Sidebar } from "./Sidebar"; export { SidebarCollapseButton } from "./SidebarCollapseButton"; export { SidebarHeader } from "./SidebarHeader"; diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index d50d28a3c..61b8c3e25 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -11,7 +11,7 @@ import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB @@ -108,12 +108,10 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen } try { - const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` - ), - { method: "GET" } + const url = new URL( + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); + const response = await authenticatedFetch(url.toString(), { method: "GET" }); if (controller.signal.aborted) return; @@ -167,7 +165,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen setSaving(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -325,9 +323,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown` - ), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 344176629..98d95b98b 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,12 +1,12 @@ "use client"; import { - AlarmClock, FilePlus2, - type LucideIcon, Search, Settings2, + type LucideIcon, WandSparkles, + Workflow, X, } from "lucide-react"; import { memo, useCallback, useState } from "react"; @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record<string, LucideIcon> = { search: Search, create: FilePlus2, - automate: AlarmClock, + automate: Workflow, tools: Settings2, }; diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 99d56eb02..4716418ee 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -1,23 +1,167 @@ "use client"; -import { ImageModelSelector } from "./image-model-selector"; +import { useCallback, useState } from "react"; +import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; +import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; +import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; +import type { + GlobalImageGenConfig, + GlobalNewLLMConfig, + GlobalVisionLLMConfig, + ImageGenerationConfig, + NewLLMConfigPublic, + VisionLLMConfig, +} from "@/contracts/types/new-llm-config.types"; import { ModelSelector } from "./model-selector"; interface ChatHeaderProps { searchSpaceId: number; className?: string; - onChatModelSelected?: () => void; } -export function ChatHeader({ searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) { +export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { + // LLM config dialog state + const [dialogOpen, setDialogOpen] = useState(false); + const [selectedConfig, setSelectedConfig] = useState< + NewLLMConfigPublic | GlobalNewLLMConfig | null + >(null); + const [isGlobal, setIsGlobal] = useState(false); + const [dialogMode, setDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Image config dialog state + const [imageDialogOpen, setImageDialogOpen] = useState(false); + const [selectedImageConfig, setSelectedImageConfig] = useState< + ImageGenerationConfig | GlobalImageGenConfig | null + >(null); + const [isImageGlobal, setIsImageGlobal] = useState(false); + const [imageDialogMode, setImageDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Vision config dialog state + const [visionDialogOpen, setVisionDialogOpen] = useState(false); + const [selectedVisionConfig, setSelectedVisionConfig] = useState< + VisionLLMConfig | GlobalVisionLLMConfig | null + >(null); + const [isVisionGlobal, setIsVisionGlobal] = useState(false); + const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Default provider for create dialogs + const [defaultLLMProvider, setDefaultLLMProvider] = useState<string | undefined>(); + const [defaultImageProvider, setDefaultImageProvider] = useState<string | undefined>(); + const [defaultVisionProvider, setDefaultVisionProvider] = useState<string | undefined>(); + + // LLM handlers + const handleEditLLMConfig = useCallback( + (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { + setSelectedConfig(config); + setIsGlobal(global); + setDialogMode(global ? "view" : "edit"); + setDefaultLLMProvider(undefined); + setDialogOpen(true); + }, + [] + ); + + const handleAddNewLLM = useCallback((provider?: string) => { + setSelectedConfig(null); + setIsGlobal(false); + setDialogMode("create"); + setDefaultLLMProvider(provider); + setDialogOpen(true); + }, []); + + const handleDialogClose = useCallback((open: boolean) => { + setDialogOpen(open); + if (!open) setSelectedConfig(null); + }, []); + + // Image model handlers + const handleAddImageModel = useCallback((provider?: string) => { + setSelectedImageConfig(null); + setIsImageGlobal(false); + setImageDialogMode("create"); + setDefaultImageProvider(provider); + setImageDialogOpen(true); + }, []); + + const handleEditImageConfig = useCallback( + (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { + setSelectedImageConfig(config); + setIsImageGlobal(global); + setImageDialogMode(global ? "view" : "edit"); + setDefaultImageProvider(undefined); + setImageDialogOpen(true); + }, + [] + ); + + const handleImageDialogClose = useCallback((open: boolean) => { + setImageDialogOpen(open); + if (!open) setSelectedImageConfig(null); + }, []); + + // Vision model handlers + const handleAddVisionModel = useCallback((provider?: string) => { + setSelectedVisionConfig(null); + setIsVisionGlobal(false); + setVisionDialogMode("create"); + setDefaultVisionProvider(provider); + setVisionDialogOpen(true); + }, []); + + const handleEditVisionConfig = useCallback( + (config: VisionLLMConfig | GlobalVisionLLMConfig, global: boolean) => { + setSelectedVisionConfig(config); + setIsVisionGlobal(global); + setVisionDialogMode(global ? "view" : "edit"); + setDefaultVisionProvider(undefined); + setVisionDialogOpen(true); + }, + [] + ); + + const handleVisionDialogClose = useCallback((open: boolean) => { + setVisionDialogOpen(open); + if (!open) setSelectedVisionConfig(null); + }, []); + return ( - <div className="flex min-w-0 shrink-0 items-center gap-2"> + <div className="flex items-center gap-2"> <ModelSelector - searchSpaceId={searchSpaceId} + onEditLLM={handleEditLLMConfig} + onAddNewLLM={handleAddNewLLM} + onEditImage={handleEditImageConfig} + onAddNewImage={handleAddImageModel} + onEditVision={handleEditVisionConfig} + onAddNewVision={handleAddVisionModel} className={className} - onChatModelSelected={onChatModelSelected} /> - <ImageModelSelector searchSpaceId={searchSpaceId} className={className} mobileIconOnly /> + <ModelConfigDialog + open={dialogOpen} + onOpenChange={handleDialogClose} + config={selectedConfig} + isGlobal={isGlobal} + searchSpaceId={searchSpaceId} + mode={dialogMode} + defaultProvider={defaultLLMProvider} + /> + <ImageConfigDialog + open={imageDialogOpen} + onOpenChange={handleImageDialogClose} + config={selectedImageConfig} + isGlobal={isImageGlobal} + searchSpaceId={searchSpaceId} + mode={imageDialogMode} + defaultProvider={defaultImageProvider} + /> + <VisionConfigDialog + open={visionDialogOpen} + onOpenChange={handleVisionDialogClose} + config={selectedVisionConfig} + isGlobal={isVisionGlobal} + searchSpaceId={searchSpaceId} + mode={visionDialogMode} + defaultProvider={defaultVisionProvider} + /> </div> ); } diff --git a/surfsense_web/components/new-chat/image-model-selector.tsx b/surfsense_web/components/new-chat/image-model-selector.tsx deleted file mode 100644 index 5cd898afc..000000000 --- a/surfsense_web/components/new-chat/image-model-selector.tsx +++ /dev/null @@ -1,311 +0,0 @@ -"use client"; - -import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, ImagePlus, Search, SlidersHorizontal } from "lucide-react"; -import { useRouter } from "next/navigation"; -import type { UIEvent } from "react"; -import { useCallback, useMemo, useState } from "react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { - Drawer, - DrawerContent, - DrawerHandle, - DrawerHeader, - DrawerTitle, - DrawerTrigger, -} from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { Spinner } from "@/components/ui/spinner"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import { useIsMobile } from "@/hooks/use-mobile"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; -import { providerDisplay } from "../settings/model-connections/provider-metadata"; - -interface ImageModelSelectorProps { - searchSpaceId: number; - className?: string; - mobileIconOnly?: boolean; -} - -type ImageModel = ModelRead & { - connectionId: number; - connectionLabel: string; - connectionScope: string; - provider: string; -}; - -const AUTO_IMAGE_MODEL_ID = 0; - -function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Global"; - return providerDisplay(connection.provider).name; -} - -function flattenImageModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models - .filter((model) => model.enabled && Boolean(model.supports_image_generation)) - .map((model) => ({ - ...model, - connectionId: connection.id, - connectionLabel: connectionLabel(connection), - connectionScope: connection.scope, - provider: connection.provider, - })) - ); -} - -function isFreeGlobalModel(model: ImageModel) { - return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; -} - -function modelName(model: ImageModel) { - const name = model.display_name || model.model_id; - if (model.connectionScope === "GLOBAL") { - return name.replace(/\s+\(free\)$/i, ""); - } - return name; -} - -function filterImageModels(models: ImageModel[], search: string) { - const normalized = search.trim().toLowerCase(); - if (!normalized) return models; - return models.filter((model) => - [modelName(model), model.model_id, model.connectionLabel] - .join(" ") - .toLowerCase() - .includes(normalized) - ); -} - -function groupedModels(models: ImageModel[]) { - return models.reduce<Record<string, ImageModel[]>>((groups, model) => { - const key = model.connectionLabel; - if (!groups[key]) groups[key] = []; - groups[key].push(model); - return groups; - }, {}); -} - -export function ImageModelSelector({ - searchSpaceId, - className, - mobileIconOnly = false, -}: ImageModelSelectorProps) { - const router = useRouter(); - const isMobile = useIsMobile(); - const [open, setOpen] = useState(false); - const [search, setSearch] = useState(""); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( - globalModelConnectionsAtom - ); - const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = useAtomValue(updateModelRolesMutationAtom); - - const allImageModels = useMemo( - () => flattenImageModels([...globalConnections, ...connections]), - [globalConnections, connections] - ); - - const visibleImageModels = useMemo( - () => filterImageModels(allImageModels, search), - [allImageModels, search] - ); - const imageModelsById = useMemo( - () => new Map(allImageModels.map((model) => [model.id, model])), - [allImageModels] - ); - const selectedModelId = roles?.image_gen_model_id ?? AUTO_IMAGE_MODEL_ID; - const selected = imageModelsById.get(selectedModelId); - const groups = useMemo(() => groupedModels(visibleImageModels), [visibleImageModels]); - const loading = globalLoading || connectionsLoading; - const hasSearchQuery = search.trim().length > 0; - const showIconOnlyTrigger = isMobile && mobileIconOnly; - - function handleOpenChange(nextOpen: boolean) { - if (!nextOpen) setSearch(""); - setOpen(nextOpen); - } - - function selectModel(modelId: number) { - updateRoles.mutate({ image_gen_model_id: modelId }); - setSearch(""); - setOpen(false); - } - - function manageModelConnections() { - setOpen(false); - router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); - } - - const handleScroll = useCallback((event: UIEvent<HTMLDivElement>) => { - const el = event.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - // Only surface this control when usable image-generation models exist. - if (!loading && allImageModels.length === 0) { - return null; - } - - const content = ( - <div className="flex h-[320px] select-none flex-col overflow-hidden"> - <div className="p-2"> - <div className="relative"> - <Search className="absolute left-0.5 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" /> - <Input - value={search} - onChange={(event) => setSearch(event.target.value)} - placeholder="Search image models" - className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" - /> - </div> - </div> - <div - className="min-h-0 flex-1 overflow-y-auto overflow-x-hidden px-1.5 py-1.5" - onScroll={handleScroll} - style={{ - maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, - WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, - }} - > - <button - type="button" - className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground" - onClick={() => selectModel(AUTO_IMAGE_MODEL_ID)} - > - <div className="min-w-0 flex-1"> - <div className="flex min-w-0 items-center gap-2 font-medium"> - {getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })} - <span className="truncate">Auto</span> - </div> - </div> - {selectedModelId === AUTO_IMAGE_MODEL_ID ? <Check className="h-4 w-4" /> : null} - </button> - {loading ? ( - <div className="flex items-center justify-center py-8"> - <Spinner /> - </div> - ) : Object.keys(groups).length === 0 ? ( - <div className="px-3 py-8 text-center text-sm text-muted-foreground"> - {hasSearchQuery - ? "No matching image models." - : "No enabled image models. Add or enable models in Settings."} - </div> - ) : ( - Object.entries(groups).map(([connection, models]) => ( - <div key={connection} className="mt-3"> - <div className="px-2 py-1 text-sm font-semibold text-muted-foreground"> - {connection} - </div> - {models.map((model) => ( - <button - type="button" - key={model.id} - className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground" - onClick={() => selectModel(model.id)} - > - <div className="min-w-0 flex-1"> - <div className="flex min-w-0 items-center gap-2 font-medium"> - {getProviderIcon(model.provider, { className: "size-4 shrink-0" })} - <span className="truncate">{modelName(model)}</span> - </div> - </div> - <div className="ml-3 flex shrink-0 items-center gap-2"> - {isFreeGlobalModel(model) ? ( - <Badge - variant="secondary" - className="h-5 shrink-0 rounded-sm border-0 bg-popover-foreground/10 px-1.5 text-[11px] text-popover-foreground hover:bg-popover-foreground/10" - > - Free - </Badge> - ) : null} - {roles?.image_gen_model_id === model.id ? <Check className="h-4 w-4" /> : null} - </div> - </button> - ))} - </div> - )) - )} - </div> - <div className="p-2"> - <Button - variant="ghost" - className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground" - onClick={manageModelConnections} - > - <SlidersHorizontal className="h-4 w-4" /> Manage models - </Button> - </div> - </div> - ); - - const trigger = ( - <Button - type="button" - variant="ghost" - size="sm" - aria-label="Select image model" - className={cn( - "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", - "select-none", - "hover:bg-foreground/10 hover:text-foreground", - "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", - className, - showIconOnlyTrigger && "h-9 w-auto shrink-0 justify-center gap-1 px-2" - )} - > - {selected ? ( - getProviderIcon(selected.provider, { className: "size-4 shrink-0" }) - ) : ( - <ImagePlus className="size-4 shrink-0" /> - )} - {showIconOnlyTrigger ? null : ( - <span className="min-w-0 flex-1 truncate text-sm"> - {selected ? modelName(selected) : "Auto"} - </span> - )} - <ChevronDown className="h-3.5 w-3.5 shrink-0" /> - </Button> - ); - - if (isMobile) { - return ( - <Drawer open={open} onOpenChange={handleOpenChange}> - <DrawerTrigger asChild>{trigger}</DrawerTrigger> - <DrawerContent className="max-h-[85vh]"> - <DrawerHandle /> - <DrawerHeader> - <DrawerTitle>Select Image Model</DrawerTitle> - </DrawerHeader> - {content} - </DrawerContent> - </Drawer> - ); - } - - return ( - <Popover open={open} onOpenChange={handleOpenChange}> - <PopoverTrigger asChild>{trigger}</PopoverTrigger> - <PopoverContent - align="start" - className="w-[340px] border border-popover-border bg-popover p-0 text-popover-foreground shadow-md" - > - {content} - </PopoverContent> - </Popover> - ); -} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index c10bfd862..0a096f5f8 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,16 +1,40 @@ "use client"; -import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, Search, SlidersHorizontal } from "lucide-react"; -import { useRouter } from "next/navigation"; -import type { UIEvent } from "react"; -import { useCallback, useMemo, useState } from "react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { useAtomValue } from "jotai"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + Bot, + Check, + ChevronDown, + ChevronLeft, + ChevronRight, + ChevronUp, + ImageIcon, + Layers, + Pencil, + Plus, + ScanEye, + Search, + Zap, +} from "lucide-react"; +import type React from "react"; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { @@ -21,288 +45,1389 @@ import { DrawerTitle, DrawerTrigger, } from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import type { + GlobalImageGenConfig, + GlobalNewLLMConfig, + GlobalVisionLLMConfig, + ImageGenerationConfig, + NewLLMConfigPublic, + VisionLLMConfig, +} from "@/contracts/types/new-llm-config.types"; import { useIsMobile } from "@/hooks/use-mobile"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; +import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; -import { providerDisplay } from "../settings/model-connections/provider-metadata"; -interface ModelSelectorProps { - searchSpaceId: number; - className?: string; - onChatModelSelected?: () => void; -} +// ─── Helpers ──────────────────────────────────────────────────────── -type ChatModel = ModelRead & { - connectionId: number; - connectionLabel: string; - connectionScope: string; - provider: string; +const PROVIDER_NAMES: Record<string, string> = { + OPENAI: "OpenAI", + ANTHROPIC: "Anthropic", + GOOGLE: "Google", + AZURE: "Azure", + AZURE_OPENAI: "Azure OpenAI", + AWS_BEDROCK: "AWS Bedrock", + BEDROCK: "Bedrock", + DEEPSEEK: "DeepSeek", + MISTRAL: "Mistral", + COHERE: "Cohere", + GITHUB_MODELS: "GitHub Models", + GROQ: "Groq", + OLLAMA: "Ollama", + TOGETHER_AI: "Together AI", + FIREWORKS_AI: "Fireworks AI", + REPLICATE: "Replicate", + HUGGINGFACE: "HuggingFace", + PERPLEXITY: "Perplexity", + XAI: "xAI", + OPENROUTER: "OpenRouter", + CEREBRAS: "Cerebras", + SAMBANOVA: "SambaNova", + VERTEX_AI: "Vertex AI", + MINIMAX: "MiniMax", + MOONSHOT: "Moonshot", + ZHIPU: "Zhipu", + DEEPINFRA: "DeepInfra", + CLOUDFLARE: "Cloudflare", + DATABRICKS: "Databricks", + NSCALE: "NScale", + RECRAFT: "Recraft", + XINFERENCE: "XInference", + CUSTOM: "Custom", + AI21: "AI21", + ALIBABA_QWEN: "Qwen", + ANYSCALE: "Anyscale", + COMETAPI: "CometAPI", }; -const AUTO_CHAT_MODEL_ID = 0; +// Provider keys valid per model type, matching backend enums +// (LiteLLMProvider, ImageGenProvider, VisionProvider in db.py) +const LLM_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "DEEPSEEK", + "XAI", + "MISTRAL", + "COHERE", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "CEREBRAS", + "SAMBANOVA", + "DEEPINFRA", + "AI21", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "MINIMAX", + "HUGGINGFACE", + "CLOUDFLARE", + "DATABRICKS", + "ANYSCALE", + "COMETAPI", + "GITHUB_MODELS", + "CUSTOM", +]; -function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Global"; - return providerDisplay(connection.provider).name; -} +const IMAGE_PROVIDER_KEYS: string[] = [ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", +]; -function flattenChatModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models - .filter((model) => model.enabled && Boolean(model.supports_chat)) - .map((model) => ({ - ...model, - connectionId: connection.id, - connectionLabel: connectionLabel(connection), - connectionScope: connection.scope, - provider: connection.provider, - })) +const VISION_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", +]; + +const PROVIDER_KEYS_BY_TAB: Record<string, string[]> = { + llm: LLM_PROVIDER_KEYS, + image: IMAGE_PROVIDER_KEYS, + vision: VISION_PROVIDER_KEYS, +}; + +function formatProviderName(provider: string): string { + const key = provider.toUpperCase(); + return ( + PROVIDER_NAMES[key] ?? + provider.charAt(0).toUpperCase() + provider.slice(1).toLowerCase().replace(/_/g, " ") ); } -function isFreeGlobalModel(model: ChatModel) { - return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; +function normalizeText(input: string): string { + return input + .normalize("NFD") + .replace(/\p{Diacritic}/gu, "") + .toLowerCase() + .replace(/[^a-z0-9]+/g, " ") + .trim(); } -function modelName(model: ChatModel) { - const name = model.display_name || model.model_id; - if (model.connectionScope === "GLOBAL") { - return name.replace(/\s+\(free\)$/i, ""); +interface ConfigBase { + id: number; + name: string; + model_name: string; + provider: string; +} + +function filterAndScore<T extends ConfigBase>( + configs: T[], + selectedProvider: string, + searchQuery: string +): T[] { + let result = configs; + + if (selectedProvider !== "all") { + result = result.filter((c) => c.provider.toUpperCase() === selectedProvider); } - return name; + + if (!searchQuery.trim()) return result; + + const normalized = normalizeText(searchQuery); + const tokens = normalized.split(/\s+/).filter(Boolean); + + const scored = result.map((c) => { + const aggregate = normalizeText([c.name, c.model_name, c.provider].join(" ")); + let score = 0; + if (aggregate.includes(normalized)) score += 5; + for (const token of tokens) { + if (aggregate.includes(token)) score += 1; + } + return { config: c, score }; + }); + + return scored + .filter((s) => s.score > 0) + .sort((a, b) => b.score - a.score) + .map((s) => s.config); } -function filterChatModels(models: ChatModel[], search: string) { - const normalized = search.trim().toLowerCase(); - if (!normalized) return models; - return models.filter((model) => - [modelName(model), model.model_id, model.connectionLabel] - .join(" ") - .toLowerCase() - .includes(normalized) +interface DisplayItem { + config: ConfigBase & Record<string, unknown>; + isGlobal: boolean; + isAutoMode: boolean; +} + +const TruncatedNameWithTooltip: React.FC<{ + text: string; + className?: string; + enableTooltip: boolean; +}> = ({ text, className, enableTooltip }) => { + const textRef = useRef<HTMLSpanElement>(null); + const openTimerRef = useRef<number | undefined>(undefined); + const [isTruncated, setIsTruncated] = useState(false); + const [open, setOpen] = useState(false); + + const recalcTruncation = useCallback(() => { + const el = textRef.current; + if (!el) return; + setIsTruncated(el.scrollWidth > el.clientWidth + 1); + }, []); + + useEffect(() => { + if (!enableTooltip) return; + const el = textRef.current; + if (!el) return; + + const raf = requestAnimationFrame(recalcTruncation); + recalcTruncation(); + + const observer = new ResizeObserver(recalcTruncation); + observer.observe(el); + if (el.parentElement) observer.observe(el.parentElement); + window.addEventListener("resize", recalcTruncation); + + return () => { + cancelAnimationFrame(raf); + observer.disconnect(); + window.removeEventListener("resize", recalcTruncation); + }; + }, [enableTooltip, recalcTruncation]); + + useEffect(() => { + // Recompute when row text changes. + void text; + requestAnimationFrame(recalcTruncation); + }, [text, recalcTruncation]); + + useEffect( + () => () => { + if (openTimerRef.current) window.clearTimeout(openTimerRef.current); + }, + [] ); -} -function groupedModels(models: ChatModel[]) { - return models.reduce<Record<string, ChatModel[]>>((groups, model) => { - const key = model.connectionLabel; - if (!groups[key]) groups[key] = []; - groups[key].push(model); - return groups; - }, {}); + if (!enableTooltip) { + return ( + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + ); + } + + const handleOpenChange = (nextOpen: boolean) => { + if (openTimerRef.current) { + window.clearTimeout(openTimerRef.current); + openTimerRef.current = undefined; + } + if (!nextOpen) { + setOpen(false); + return; + } + if (!isTruncated) return; + openTimerRef.current = window.setTimeout(() => { + setOpen(true); + openTimerRef.current = undefined; + }, 220); + }; + + return ( + <Tooltip open={open} onOpenChange={handleOpenChange}> + <TooltipTrigger asChild> + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + </TooltipTrigger> + <TooltipContent side="top" align="start"> + {text} + </TooltipContent> + </Tooltip> + ); +}; + +// ─── Component ────────────────────────────────────────────────────── + +interface ModelSelectorProps { + onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; + onAddNewLLM: (provider?: string) => void; + onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; + onAddNewImage?: (provider?: string) => void; + onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; + onAddNewVision?: (provider?: string) => void; + className?: string; } export function ModelSelector({ - searchSpaceId, + onEditLLM, + onAddNewLLM, + onEditImage, + onAddNewImage, + onEditVision, + onAddNewVision, className, - onChatModelSelected, }: ModelSelectorProps) { - const router = useRouter(); - const isMobile = useIsMobile(); const [open, setOpen] = useState(false); - const [search, setSearch] = useState(""); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( - globalModelConnectionsAtom - ); - const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = useAtomValue(updateModelRolesMutationAtom); + const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">("llm"); + const [searchQuery, setSearchQuery] = useState(""); + const [selectedProvider, setSelectedProvider] = useState<string>("all"); + const [focusedIndex, setFocusedIndex] = useState(-1); + const [modelScrollPos, setModelScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const [sidebarScrollPos, setSidebarScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const providerSidebarRef = useRef<HTMLDivElement>(null); + const modelListRef = useRef<HTMLDivElement>(null); + const searchInputRef = useRef<HTMLInputElement>(null); + const isMobile = useIsMobile(); - const allChatModels = useMemo( - () => flattenChatModels([...globalConnections, ...connections]), - [globalConnections, connections] + const handleOpenChange = useCallback( + (next: boolean) => { + if (next) { + setSearchQuery(""); + setSelectedProvider("all"); + if (!isMobile) { + requestAnimationFrame(() => searchInputRef.current?.focus()); + } + } + setOpen(next); + }, + [isMobile] ); - const visibleChatModels = useMemo( - () => filterChatModels(allChatModels, search), - [allChatModels, search] + const handleTabChange = useCallback( + (next: "llm" | "image" | "vision") => { + setActiveTab(next); + setSelectedProvider("all"); + setSearchQuery(""); + setFocusedIndex(-1); + setModelScrollPos("top"); + if (open && !isMobile) { + requestAnimationFrame(() => searchInputRef.current?.focus()); + } + }, + [open, isMobile] ); - const chatModelsById = useMemo( - () => new Map(allChatModels.map((model) => [model.id, model])), - [allChatModels] - ); - const selectedModelId = roles?.chat_model_id ?? AUTO_CHAT_MODEL_ID; - const selected = chatModelsById.get(selectedModelId); - const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]); - const loading = globalLoading || connectionsLoading; - const hasSearchQuery = search.trim().length > 0; - const showIconOnlyTrigger = isMobile; - function handleOpenChange(nextOpen: boolean) { - if (!nextOpen) setSearch(""); - setOpen(nextOpen); - } - - function selectModel(modelId: number) { - updateRoles.mutate({ chat_model_id: modelId }); - setSearch(""); - setOpen(false); - requestAnimationFrame(() => { - onChatModelSelected?.(); - }); - } - - function manageModelConnections() { - setOpen(false); - router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); - } - - const handleScroll = useCallback((event: UIEvent<HTMLDivElement>) => { - const el = event.currentTarget; + const handleModelListScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => { + const el = e.currentTarget; const atTop = el.scrollTop <= 2; const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + setModelScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); }, []); - const content = ( - <div className="flex h-[320px] select-none flex-col overflow-hidden"> - <div className="p-2"> - <div className="relative"> - <Search className="absolute left-0.5 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" /> - <Input - value={search} - onChange={(event) => setSearch(event.target.value)} - placeholder="Search chat models" - className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" - /> - </div> - </div> - <div - className="min-h-0 flex-1 overflow-y-auto overflow-x-hidden px-1.5 py-1.5" - onScroll={handleScroll} - style={{ - maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, - WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, - }} - > - <button - type="button" - className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground" - onClick={() => selectModel(AUTO_CHAT_MODEL_ID)} - > - <div className="min-w-0 flex-1"> - <div className="flex min-w-0 items-center gap-2 font-medium"> - {getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })} - <span className="truncate">Auto</span> - </div> - </div> - {selectedModelId === AUTO_CHAT_MODEL_ID ? <Check className="h-4 w-4" /> : null} - </button> - {loading ? ( - <div className="flex items-center justify-center py-8"> - <Spinner /> - </div> - ) : Object.keys(groups).length === 0 ? ( - <div className="px-3 py-8 text-center text-sm text-muted-foreground"> - {hasSearchQuery - ? "No matching chat models." - : "No enabled chat models. Add or enable models in Settings."} - </div> - ) : ( - Object.entries(groups).map(([connection, models]) => ( - <div key={connection} className="mt-3"> - <div className="px-2 py-1 text-sm font-semibold text-muted-foreground"> - {connection} - </div> - {models.map((model) => ( - <button - type="button" - key={model.id} - className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground" - onClick={() => selectModel(model.id)} - > - <div className="min-w-0 flex-1"> - <div className="flex min-w-0 items-center gap-2 font-medium"> - {getProviderIcon(model.provider, { className: "size-4 shrink-0" })} - <span className="truncate">{modelName(model)}</span> - </div> - {/* {model.max_input_tokens ? ( - <div className="text-xs text-muted-foreground"> - {model.max_input_tokens.toLocaleString()} context - </div> - ) : null} */} - </div> - <div className="ml-3 flex shrink-0 items-center gap-2"> - {isFreeGlobalModel(model) ? ( - <Badge - variant="secondary" - className="h-5 shrink-0 rounded-sm border-0 bg-popover-foreground/10 px-1.5 text-[11px] text-popover-foreground hover:bg-popover-foreground/10" - > - Free - </Badge> - ) : null} - {/* - Re-enable this once the chat composer supports image input. - For now, surfacing `supports_image_input` in the chat model - selector is misleading because users cannot attach images. - - {!model.supports_image_input ? ( - <Badge variant="outline" className="gap-1"> - <ImageOff className="h-3 w-3" /> No image - </Badge> - ) : null} - */} - {roles?.chat_model_id === model.id ? <Check className="h-4 w-4" /> : null} - </div> - </button> - ))} - </div> - )) - )} - </div> - <div className="p-2"> - <Button - variant="ghost" - className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground" - onClick={manageModelConnections} - > - <SlidersHorizontal className="h-4 w-4" /> Manage models - </Button> - </div> - </div> + const handleSidebarScroll = useCallback( + (e: React.UIEvent<HTMLDivElement>) => { + const el = e.currentTarget; + if (isMobile) { + const atStart = el.scrollLeft <= 2; + const atEnd = el.scrollWidth - el.scrollLeft - el.clientWidth <= 2; + setSidebarScrollPos(atStart ? "top" : atEnd ? "bottom" : "middle"); + } else { + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setSidebarScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + } + }, + [isMobile] ); - const trigger = ( + const scrollProviderSidebar = useCallback( + (direction: "backward" | "forward") => { + const el = providerSidebarRef.current; + if (!el) return; + const delta = isMobile + ? Math.max(56, Math.floor(el.clientWidth * 0.5)) + : Math.max(44, Math.floor(el.clientHeight * 0.4)); + + if (isMobile) { + el.scrollBy({ + left: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + return; + } + + el.scrollBy({ + top: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + }, + [isMobile] + ); + + // Cmd/Ctrl+M shortcut (desktop only) + useEffect(() => { + if (isMobile) return; + const handler = (e: KeyboardEvent) => { + if ((e.metaKey || e.ctrlKey) && e.key === "m") { + e.preventDefault(); + // setOpen((prev) => !prev); + handleOpenChange(!open); + } + }; + document.addEventListener("keydown", handler); + return () => document.removeEventListener("keydown", handler); + }, [isMobile, open, handleOpenChange]); + + // ─── Data ─── + const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); + const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = + useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); + const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = + useAtomValue(globalImageGenConfigsAtom); + const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); + const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( + globalVisionLLMConfigsAtom + ); + const { data: visionUserConfigs, isLoading: visionUserLoading } = + useAtomValue(visionLLMConfigsAtom); + + // Pending image attachments on the composer. Used to surface an + // amber "No image" hint on chat models the catalog reports as + // non-vision (`supports_image_input=false`) when the next message + // will carry an image. The hint is purely advisory: selection, + // focus, and click handling are unaffected. The backend's safety + // net (`is_known_text_only_chat_model`) is the actual block, and + // it only fires when LiteLLM *explicitly* marks a model as + // text-only — so a model that's secretly capable but hasn't been + // annotated will still flow through to the provider. + const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); + const hasPendingImages = pendingUserImageUrls.length > 0; + + const isLoading = + llmUserLoading || + llmGlobalLoading || + prefsLoading || + imageGlobalLoading || + imageUserLoading || + visionGlobalLoading || + visionUserLoading; + + // ─── Current selected configs ─── + const currentLLMConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.agent_llm_id; + if (id === null || id === undefined) return null; + if (id <= 0) return llmGlobalConfigs?.find((c) => c.id === id) ?? null; + return llmUserConfigs?.find((c) => c.id === id) ?? null; + }, [preferences, llmGlobalConfigs, llmUserConfigs]); + + const isLLMAutoMode = + currentLLMConfig && "is_auto_mode" in currentLLMConfig && currentLLMConfig.is_auto_mode; + + const currentImageConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.image_generation_config_id; + if (id === null || id === undefined) return null; + return ( + imageGlobalConfigs?.find((c) => c.id === id) ?? + imageUserConfigs?.find((c) => c.id === id) ?? + null + ); + }, [preferences, imageGlobalConfigs, imageUserConfigs]); + + const isImageAutoMode = + currentImageConfig && "is_auto_mode" in currentImageConfig && currentImageConfig.is_auto_mode; + + const currentVisionConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.vision_llm_config_id; + if (id === null || id === undefined) return null; + return ( + visionGlobalConfigs?.find((c) => c.id === id) ?? + visionUserConfigs?.find((c) => c.id === id) ?? + null + ); + }, [preferences, visionGlobalConfigs, visionUserConfigs]); + + const isVisionAutoMode = + currentVisionConfig && + "is_auto_mode" in currentVisionConfig && + currentVisionConfig.is_auto_mode; + + // ─── Filtered configs (separate global / user for section headers) ─── + const filteredLLMGlobal = useMemo( + () => filterAndScore(llmGlobalConfigs ?? [], selectedProvider, searchQuery), + [llmGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredLLMUser = useMemo( + () => filterAndScore(llmUserConfigs ?? [], selectedProvider, searchQuery), + [llmUserConfigs, selectedProvider, searchQuery] + ); + const filteredImageGlobal = useMemo( + () => filterAndScore(imageGlobalConfigs ?? [], selectedProvider, searchQuery), + [imageGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredImageUser = useMemo( + () => filterAndScore(imageUserConfigs ?? [], selectedProvider, searchQuery), + [imageUserConfigs, selectedProvider, searchQuery] + ); + const filteredVisionGlobal = useMemo( + () => filterAndScore(visionGlobalConfigs ?? [], selectedProvider, searchQuery), + [visionGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredVisionUser = useMemo( + () => filterAndScore(visionUserConfigs ?? [], selectedProvider, searchQuery), + [visionUserConfigs, selectedProvider, searchQuery] + ); + + // Combined display list for keyboard navigation + const currentDisplayItems: DisplayItem[] = useMemo(() => { + const toItems = (configs: ConfigBase[], isGlobal: boolean): DisplayItem[] => + configs.map((c) => ({ + config: c as ConfigBase & Record<string, unknown>, + isGlobal, + isAutoMode: + isGlobal && "is_auto_mode" in c && !!(c as Record<string, unknown>).is_auto_mode, + })); + + const sortGlobalItems = (items: DisplayItem[]): DisplayItem[] => + [...items].sort((a, b) => { + if (a.isAutoMode !== b.isAutoMode) return a.isAutoMode ? -1 : 1; + const aPremium = !!(a.config as Record<string, unknown>).is_premium; + const bPremium = !!(b.config as Record<string, unknown>).is_premium; + if (aPremium !== bPremium) return aPremium ? 1 : -1; + return 0; + }); + + switch (activeTab) { + case "llm": + return [ + ...sortGlobalItems(toItems(filteredLLMGlobal, true)), + ...toItems(filteredLLMUser, false), + ]; + case "image": + return [ + ...sortGlobalItems(toItems(filteredImageGlobal, true)), + ...toItems(filteredImageUser, false), + ]; + case "vision": + return [ + ...sortGlobalItems(toItems(filteredVisionGlobal, true)), + ...toItems(filteredVisionUser, false), + ]; + } + }, [ + activeTab, + filteredLLMGlobal, + filteredLLMUser, + filteredImageGlobal, + filteredImageUser, + filteredVisionGlobal, + filteredVisionUser, + ]); + + // ─── Provider sidebar data ─── + // Collect which providers actually have configured models for the active tab + const configuredProviderSet = useMemo(() => { + const configs = + activeTab === "llm" + ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] + : activeTab === "image" + ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] + : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; + const set = new Set<string>(); + for (const c of configs) { + if (c.provider) set.add(c.provider.toUpperCase()); + } + return set; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); + + // Show only providers valid for the active tab; configured ones first + const activeProviders = useMemo(() => { + const tabKeys = PROVIDER_KEYS_BY_TAB[activeTab] ?? LLM_PROVIDER_KEYS; + const configured = tabKeys.filter((p) => configuredProviderSet.has(p)); + const unconfigured = tabKeys.filter((p) => !configuredProviderSet.has(p)); + return ["all", ...configured, ...unconfigured]; + }, [activeTab, configuredProviderSet]); + + const providerModelCounts = useMemo(() => { + const allConfigs = + activeTab === "llm" + ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] + : activeTab === "image" + ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] + : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; + const counts: Record<string, number> = { all: allConfigs.length }; + for (const c of allConfigs) { + const p = c.provider.toUpperCase(); + counts[p] = (counts[p] || 0) + 1; + } + return counts; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); + + // ─── Selection handlers ─── + const handleSelectLLM = useCallback( + async (config: NewLLMConfigPublic | GlobalNewLLMConfig) => { + if (currentLLMConfig?.id === config.id) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { agent_llm_id: config.id }, + }); + toast.success(`Switched to ${config.name}`); + setOpen(false); + } catch { + toast.error("Failed to switch model"); + } + }, + [currentLLMConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectImage = useCallback( + async (configId: number) => { + if (currentImageConfig?.id === configId) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { image_generation_config_id: configId }, + }); + toast.success("Image model updated"); + setOpen(false); + } catch { + toast.error("Failed to switch image model"); + } + }, + [currentImageConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectVision = useCallback( + async (configId: number) => { + if (currentVisionConfig?.id === configId) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { vision_llm_config_id: configId }, + }); + toast.success("Vision model updated"); + setOpen(false); + } catch { + toast.error("Failed to switch vision model"); + } + }, + [currentVisionConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectItem = useCallback( + (item: DisplayItem) => { + switch (activeTab) { + case "llm": + handleSelectLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig); + break; + case "image": + handleSelectImage(item.config.id); + break; + case "vision": + handleSelectVision(item.config.id); + break; + } + }, + [activeTab, handleSelectLLM, handleSelectImage, handleSelectVision] + ); + + const handleEditItem = useCallback( + (e: React.MouseEvent, item: DisplayItem) => { + e.stopPropagation(); + setOpen(false); + switch (activeTab) { + case "llm": + onEditLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig, item.isGlobal); + break; + case "image": + onEditImage?.(item.config as ImageGenerationConfig | GlobalImageGenConfig, item.isGlobal); + break; + case "vision": + onEditVision?.(item.config as VisionLLMConfig | GlobalVisionLLMConfig, item.isGlobal); + break; + } + }, + [activeTab, onEditLLM, onEditImage, onEditVision] + ); + + // ─── Keyboard navigation ─── + // biome-ignore lint/correctness/useExhaustiveDependencies: searchQuery and selectedProvider are intentional triggers to reset focus + useEffect(() => { + setFocusedIndex(-1); + }, [searchQuery, selectedProvider]); + + useEffect(() => { + if (focusedIndex < 0 || !modelListRef.current) return; + const items = modelListRef.current.querySelectorAll("[data-model-index]"); + items[focusedIndex]?.scrollIntoView({ + block: "nearest", + behavior: "smooth", + }); + }, [focusedIndex]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent<HTMLInputElement>) => { + const count = currentDisplayItems.length; + + // Arrow Left/Right cycle provider filters + if (e.key === "ArrowLeft" || e.key === "ArrowRight") { + e.preventDefault(); + const providers = activeProviders; + const idx = providers.indexOf(selectedProvider); + let next: number; + if (e.key === "ArrowLeft") { + next = idx > 0 ? idx - 1 : providers.length - 1; + } else { + next = idx < providers.length - 1 ? idx + 1 : 0; + } + setSelectedProvider(providers[next]); + if (providerSidebarRef.current) { + const buttons = providerSidebarRef.current.querySelectorAll("button"); + buttons[next]?.scrollIntoView({ + block: "nearest", + inline: "nearest", + behavior: "smooth", + }); + } + return; + } + + if (count === 0) return; + + switch (e.key) { + case "ArrowDown": + e.preventDefault(); + setFocusedIndex((prev) => (prev < count - 1 ? prev + 1 : 0)); + break; + case "ArrowUp": + e.preventDefault(); + setFocusedIndex((prev) => (prev > 0 ? prev - 1 : count - 1)); + break; + case "Enter": + e.preventDefault(); + if (focusedIndex >= 0 && focusedIndex < count) { + handleSelectItem(currentDisplayItems[focusedIndex]); + } + break; + case "Home": + e.preventDefault(); + setFocusedIndex(0); + break; + case "End": + e.preventDefault(); + setFocusedIndex(count - 1); + break; + } + }, + [currentDisplayItems, focusedIndex, activeProviders, selectedProvider, handleSelectItem] + ); + + // ─── Render: Provider sidebar ─── + const renderProviderSidebar = () => { + const configuredCount = configuredProviderSet.size; + + return ( + <div + className={cn( + "shrink-0 border-popover-border flex relative", + isMobile ? "flex-row items-center border-b" : "flex-col w-10 border-r" + )} + > + {!isMobile && ( + <div + className={cn( + "absolute top-0 left-0 right-0 z-10 h-5 flex items-center justify-center transition-all duration-200 ease-out", + sidebarScrollPos === "top" + ? "opacity-0 -translate-y-1 pointer-events-none" + : "opacity-100 translate-y-0 pointer-events-auto" + )} + > + <Button + type="button" + variant="ghost" + aria-label="Scroll providers up" + onClick={() => scrollProviderSidebar("backward")} + className="h-4 w-4 rounded-sm p-0 text-muted-foreground/90 hover:bg-accent hover:text-accent-foreground" + > + <ChevronUp className="size-3" /> + </Button> + </div> + )} + {isMobile && ( + <div + className={cn( + "absolute left-0 top-0 bottom-0 z-10 w-5 flex items-center justify-center transition-all duration-200 ease-out pointer-events-none", + sidebarScrollPos === "top" ? "opacity-0 -translate-x-1" : "opacity-100 translate-x-0" + )} + > + <ChevronLeft className="size-3 text-muted-foreground" /> + </div> + )} + <div + ref={providerSidebarRef} + onScroll={handleSidebarScroll} + className={cn( + isMobile + ? "flex flex-row gap-0.5 px-1 py-1.5 overflow-x-auto [&::-webkit-scrollbar]:h-0 [&::-webkit-scrollbar-track]:bg-transparent" + : "flex flex-col gap-0.5 p-1 overflow-y-auto flex-1 [&::-webkit-scrollbar]:w-0 [&::-webkit-scrollbar-track]:bg-transparent" + )} + style={ + isMobile + ? { + maskImage: `linear-gradient(to right, ${sidebarScrollPos === "top" ? "black" : "transparent"}, black 24px, black calc(100% - 24px), ${sidebarScrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to right, ${sidebarScrollPos === "top" ? "black" : "transparent"}, black 24px, black calc(100% - 24px), ${sidebarScrollPos === "bottom" ? "black" : "transparent"})`, + } + : { + maskImage: `linear-gradient(to bottom, ${sidebarScrollPos === "top" ? "black" : "transparent"}, black 32px, black calc(100% - 32px), ${sidebarScrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to bottom, ${sidebarScrollPos === "top" ? "black" : "transparent"}, black 32px, black calc(100% - 32px), ${sidebarScrollPos === "bottom" ? "black" : "transparent"})`, + } + } + > + {activeProviders.map((provider, idx) => { + const isAll = provider === "all"; + const isActive = selectedProvider === provider; + const count = providerModelCounts[provider] || 0; + const isConfigured = isAll || configuredProviderSet.has(provider); + + // Separator between configured and unconfigured providers + // idx 0 is "all", configured run from 1..configuredCount, unconfigured start at configuredCount+1 + const showSeparator = !isAll && idx === configuredCount + 1 && configuredCount > 0; + + return ( + <Fragment key={provider}> + {showSeparator && + (isMobile ? ( + <div className="w-px h-5 bg-popover-border shrink-0 self-center mx-0.5" /> + ) : ( + <div className="h-px w-5 bg-popover-border mx-auto my-0.5" /> + ))} + <Tooltip> + <TooltipTrigger asChild> + <Button + type="button" + variant="ghost" + onClick={() => setSelectedProvider(provider)} + tabIndex={-1} + className={cn( + "relative h-auto rounded-md transition-all duration-150", + isMobile ? "p-2 shrink-0" : "p-1.5 w-full", + isActive + ? "bg-primary/10 text-primary" + : isConfigured + ? "hover:bg-accent text-muted-foreground hover:text-accent-foreground" + : "opacity-50 hover:opacity-80 hover:bg-accent/40 text-muted-foreground" + )} + > + {isAll ? ( + <Layers className="size-4" /> + ) : ( + getProviderIcon(provider, { + className: "size-4", + }) + )} + </Button> + </TooltipTrigger> + <TooltipContent side={isMobile ? "bottom" : "right"}> + {isAll ? "All Models" : formatProviderName(provider)} + {isConfigured ? ` (${count})` : " (not configured)"} + </TooltipContent> + </Tooltip> + </Fragment> + ); + })} + </div> + {!isMobile && ( + <div + className={cn( + "absolute bottom-0 left-0 right-0 z-10 h-5 flex items-center justify-center transition-all duration-200 ease-out", + sidebarScrollPos === "bottom" + ? "opacity-0 translate-y-1 pointer-events-none" + : "opacity-100 translate-y-0 pointer-events-auto" + )} + > + <Button + type="button" + variant="ghost" + aria-label="Scroll providers down" + onClick={() => scrollProviderSidebar("forward")} + className="h-4 w-4 rounded-sm p-0 text-muted-foreground/90 hover:bg-accent hover:text-accent-foreground" + > + <ChevronDown className="size-3" /> + </Button> + </div> + )} + {isMobile && ( + <div + className={cn( + "absolute right-0 top-0 bottom-0 z-10 w-5 flex items-center justify-center transition-all duration-200 ease-out pointer-events-none", + sidebarScrollPos === "bottom" + ? "opacity-0 translate-x-1" + : "opacity-100 translate-x-0" + )} + > + <ChevronRight className="size-3 text-muted-foreground" /> + </div> + )} + </div> + ); + }; + + // ─── Render: Model card ─── + const getSelectedId = () => { + switch (activeTab) { + case "llm": + return currentLLMConfig?.id; + case "image": + return currentImageConfig?.id; + case "vision": + return currentVisionConfig?.id; + } + }; + + const renderModelCard = (item: DisplayItem, index: number) => { + const { config, isAutoMode } = item; + const isSelected = getSelectedId() === config.id; + const isFocused = focusedIndex === index; + const hasCitations = "citations_enabled" in config && !!config.citations_enabled; + const hasPremiumStatus = "is_premium" in config && !isAutoMode; + const isPremium = hasPremiumStatus && !!(config as Record<string, unknown>).is_premium; + // Chat-tab only: surface an amber "No image" hint when the + // composer carries images and the catalog reports the model as + // non-vision. This is purely advisory — selection is *not* + // blocked. The backend's narrow safety net + // (`is_known_text_only_chat_model`) is the source of truth for + // rejecting image turns, and it only fires when LiteLLM + // explicitly marks the model as text-only. A model surfaced as + // `supports_image_input=false` here may still be capable in + // practice (unknown / unmapped LiteLLM entry), so we let the + // user pick it and the provider response decide. + const isImageIncompatibleChatModel = + activeTab === "llm" && + hasPendingImages && + "supports_image_input" in config && + (config as Record<string, unknown>).supports_image_input === false; + + return ( + <div + key={`${activeTab}-${item.isGlobal ? "g" : "u"}-${config.id}`} + data-model-index={index} + role="option" + tabIndex={isMobile ? -1 : 0} + aria-selected={isSelected} + title={ + isImageIncompatibleChatModel + ? "This model is reported as text-only. You can still pick it; the provider may reject image turns." + : undefined + } + onClick={() => handleSelectItem(item)} + onKeyDown={ + isMobile + ? undefined + : (e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + handleSelectItem(item); + } + } + } + onMouseEnter={() => setFocusedIndex(index)} + className={cn( + "group flex items-center gap-2.5 px-3 py-2 rounded-xl", + "transition-colors duration-150 mx-2 cursor-pointer", + "hover:bg-accent hover:text-accent-foreground", + isFocused && "bg-accent text-accent-foreground", + isSelected && "bg-accent text-accent-foreground" + )} + > + {/* Provider icon */} + <div className="shrink-0"> + {getProviderIcon(config.provider as string, { + isAutoMode, + className: "size-5", + })} + </div> + + {/* Model info */} + <div className="flex-1 min-w-0"> + <div className="flex items-center gap-1.5"> + <TruncatedNameWithTooltip + text={config.name} + enableTooltip={!isMobile} + className="font-medium text-sm truncate" + /> + {isAutoMode && ( + <Badge + variant="secondary" + className="text-[9px] px-1 py-0 h-3.5 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 border-0" + > + Recommended + </Badge> + )} + {isImageIncompatibleChatModel && ( + <Badge + variant="secondary" + className="text-[9px] px-1 py-0 h-3.5 bg-amber-100 text-amber-700 dark:bg-amber-900/50 dark:text-amber-300 border-0" + > + No image + </Badge> + )} + </div> + {isAutoMode ? ( + <div className="flex items-center gap-1.5 mt-0.5"> + <span className="text-xs text-muted-foreground truncate">Auto Mode</span> + </div> + ) : ( + (hasPremiumStatus || hasCitations) && ( + <div className="flex items-center gap-1.5 mt-0.5"> + {hasPremiumStatus && ( + <Badge + variant="secondary" + className={cn( + "text-[10px] px-1.5 py-0.5 border-0", + isPremium + ? "bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300" + : "bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300" + )} + > + {isPremium ? "Premium" : "Free"} + </Badge> + )} + {hasCitations && ( + <Badge + variant="secondary" + className="text-[10px] px-1.5 py-0.5 border-0 bg-neutral-200 text-neutral-700 dark:bg-neutral-700 dark:text-neutral-200" + > + Citations + </Badge> + )} + </div> + ) + )} + </div> + + {/* Actions */} + <div className="flex items-center gap-1 shrink-0"> + {!isAutoMode && ( + <Button + variant="ghost" + size="icon" + className="size-7 rounded-md hover:bg-accent hover:text-accent-foreground opacity-0 group-hover:opacity-100 transition-opacity" + onClick={(e) => handleEditItem(e, item)} + > + <Pencil className="size-3.5 text-muted-foreground" /> + </Button> + )} + {isSelected && ( + <div className="size-7 grid place-items-center shrink-0"> + <Check className="size-4" /> + </div> + )} + </div> + </div> + ); + }; + + // ─── Render: Full content ─── + const renderContent = () => { + const globalItems = currentDisplayItems.filter((i) => i.isGlobal); + const userItems = currentDisplayItems.filter((i) => !i.isGlobal); + const globalStartIdx = 0; + const userStartIdx = globalItems.length; + + const addHandler = + activeTab === "llm" ? onAddNewLLM : activeTab === "image" ? onAddNewImage : onAddNewVision; + const addLabel = + activeTab === "llm" + ? "Add Model" + : activeTab === "image" + ? "Add Image Model" + : "Add Vision Model"; + + return ( + <div className="flex flex-col w-full overflow-hidden"> + {/* Tab header */} + <div className="border-b border-popover-border"> + <div className="w-full grid grid-cols-3 h-11"> + {( + [ + { + value: "llm" as const, + icon: Zap, + label: "LLM", + }, + { + value: "image" as const, + icon: ImageIcon, + label: "Image", + }, + { + value: "vision" as const, + icon: ScanEye, + label: "Vision", + }, + ] as const + ).map(({ value, icon: Icon, label }) => ( + <Button + key={value} + type="button" + variant="ghost" + // onClick={() => setActiveTab(value)} + onClick={() => handleTabChange(value)} + className={cn( + "h-auto rounded-none px-0 py-0 gap-1.5 text-sm font-medium transition-all duration-200 border-b-[1.5px] hover:bg-transparent", + activeTab === value + ? "border-foreground dark:border-white text-foreground" + : "border-transparent text-muted-foreground hover:text-accent-foreground" + )} + > + <Icon className="size-3.5" /> + {label} + </Button> + ))} + </div> + </div> + + {/* Two-pane layout */} + <div className={cn("flex", isMobile ? "flex-col h-[60vh]" : "flex-row h-[380px]")}> + {/* Provider sidebar */} + {renderProviderSidebar()} + + {/* Main content */} + <div className="flex flex-col min-w-0 min-h-0 flex-1 overflow-hidden"> + {/* Search */} + <div className="relative"> + <Search className="absolute left-3 top-1/2 -translate-y-1/2 size-3.5 text-muted-foreground/100 pointer-events-none" /> + <input + ref={searchInputRef} + placeholder="Search models" + value={searchQuery} + onChange={(e) => setSearchQuery(e.target.value)} + onKeyDown={isMobile ? undefined : handleKeyDown} + role="combobox" + aria-expanded={true} + aria-controls="model-selector-list" + className={cn( + "w-full pl-8 pr-3 py-2.5 text-sm bg-transparent", + "focus:outline-none", + "placeholder:text-muted-foreground" + )} + /> + </div> + + {/* Provider header when filtered */} + {selectedProvider !== "all" && ( + <div className="flex items-center gap-2 px-3 py-1.5"> + {getProviderIcon(selectedProvider, { + className: "size-4", + })} + <span className="text-sm font-medium">{formatProviderName(selectedProvider)}</span> + <span className="text-xs text-muted-foreground ml-auto"> + {configuredProviderSet.has(selectedProvider) + ? `${providerModelCounts[selectedProvider] || 0} models` + : "Not configured"} + </span> + </div> + )} + + {/* Model list */} + <div + id="model-selector-list" + ref={modelListRef} + role="listbox" + className="overflow-y-auto flex-1 py-1 space-y-1 flex flex-col" + onScroll={handleModelListScroll} + style={{ + maskImage: `linear-gradient(to bottom, ${modelScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${modelScrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to bottom, ${modelScrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${modelScrollPos === "bottom" ? "black" : "transparent"})`, + }} + > + {currentDisplayItems.length === 0 ? ( + <div className="flex-1 flex flex-col items-center justify-center gap-3 px-4"> + {selectedProvider !== "all" && !configuredProviderSet.has(selectedProvider) ? ( + <> + <div className="opacity-40"> + {getProviderIcon(selectedProvider, { + className: "size-10", + })} + </div> + <p className="text-sm font-medium text-muted-foreground"> + No {formatProviderName(selectedProvider)} models configured + </p> + <p className="text-xs text-muted-foreground/60 text-center"> + Add a model with this provider to get started + </p> + {addHandler && ( + <Button + variant="secondary" + size="sm" + className="mt-1" + onClick={() => { + setOpen(false); + addHandler(selectedProvider !== "all" ? selectedProvider : undefined); + }} + > + {addLabel} + </Button> + )} + </> + ) : searchQuery ? ( + <> + <Search className="size-8 text-muted-foreground" /> + <p className="text-sm text-muted-foreground">No models found</p> + <p className="text-xs text-muted-foreground/60"> + Try a different search term + </p> + </> + ) : ( + <> + <p className="text-sm font-medium text-muted-foreground"> + No models configured + </p> + <p className="text-xs text-muted-foreground/60 text-center"> + Configure models in your search space settings + </p> + </> + )} + </div> + ) : ( + <> + {globalItems.length > 0 && ( + <> + <div className="flex items-center gap-2 px-3 py-1.5 text-[12px] font-semibold text-muted-foreground tracking-wider"> + Global Models + </div> + {globalItems.map((item, i) => renderModelCard(item, globalStartIdx + i))} + </> + )} + {globalItems.length > 0 && userItems.length > 0 && ( + <div className="my-1.5 mx-4 h-px bg-popover-border" /> + )} + {userItems.length > 0 && ( + <> + <div className="flex items-center gap-2 px-3 py-1.5 text-[12px] font-semibold text-muted-foreground tracking-wider"> + Your Configurations + </div> + {userItems.map((item, i) => renderModelCard(item, userStartIdx + i))} + </> + )} + </> + )} + </div> + + {/* Add model button */} + {addHandler && ( + <div className="p-2"> + <Button + variant="ghost" + size="sm" + className="w-full justify-start gap-2 h-9 rounded-lg hover:bg-accent hover:text-accent-foreground " + onClick={() => { + setOpen(false); + addHandler(selectedProvider !== "all" ? selectedProvider : undefined); + }} + > + <Plus className="size-4 text-primary" /> + <span className="text-sm font-medium">{addLabel}</span> + </Button> + </div> + )} + </div> + </div> + </div> + ); + }; + + // ─── Trigger button ─── + const triggerButton = ( <Button - type="button" variant="ghost" size="sm" + role="combobox" + aria-expanded={open} className={cn( - "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", - "select-none", - "hover:bg-foreground/10 hover:text-foreground", - "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", - className, - showIconOnlyTrigger && "h-9 w-auto shrink-0 justify-center gap-1 px-2" + "h-8 gap-2 px-3 text-sm bg-muted shadow-xs hover:bg-accent hover:text-accent-foreground border-0 select-none", + className )} > - {selected - ? getProviderIcon(selected.provider, { className: "size-4 shrink-0" }) - : getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })} - {showIconOnlyTrigger ? null : ( - <span className="min-w-0 flex-1 truncate text-sm"> - {selected ? modelName(selected) : "Auto"} - </span> + {isLoading ? ( + <> + <Spinner size="sm" className="text-muted-foreground" /> + <span className="text-muted-foreground hidden md:inline">Loading</span> + </> + ) : ( + <> + {/* LLM */} + {currentLLMConfig ? ( + <> + {getProviderIcon(currentLLMConfig.provider, { + isAutoMode: isLLMAutoMode ?? false, + })} + <span className="max-w-[100px] md:max-w-[120px] truncate hidden md:inline"> + {currentLLMConfig.name} + </span> + </> + ) : ( + <> + <Bot className="size-4 text-muted-foreground" /> + <span className="text-muted-foreground hidden md:inline">Select Model</span> + </> + )} + <div className="h-4 w-px bg-border/60 dark:bg-white/10 mx-0.5" /> + {/* Image */} + {currentImageConfig ? ( + <> + {getProviderIcon(currentImageConfig.provider, { + isAutoMode: isImageAutoMode ?? false, + })} + <span className="max-w-[80px] md:max-w-[100px] truncate hidden md:inline"> + {currentImageConfig.name} + </span> + </> + ) : ( + <ImageIcon className="size-4 text-muted-foreground" /> + )} + <div className="h-4 w-px bg-border/60 dark:bg-white/10 mx-0.5" /> + {/* Vision */} + {currentVisionConfig ? ( + <> + {getProviderIcon(currentVisionConfig.provider, { + isAutoMode: isVisionAutoMode ?? false, + })} + <span className="max-w-[80px] md:max-w-[100px] truncate hidden md:inline"> + {currentVisionConfig.name} + </span> + </> + ) : ( + <ScanEye className="size-4 text-muted-foreground" /> + )} + </> )} - <ChevronDown className="h-3.5 w-3.5 shrink-0" /> + <ChevronDown className="h-3.5 w-3.5 text-muted-foreground ml-1 shrink-0" /> </Button> ); + // ─── Shell: Drawer on mobile, Popover on desktop ─── if (isMobile) { return ( <Drawer open={open} onOpenChange={handleOpenChange}> - <DrawerTrigger asChild>{trigger}</DrawerTrigger> + <DrawerTrigger asChild>{triggerButton}</DrawerTrigger> <DrawerContent className="max-h-[85vh]"> <DrawerHandle /> - <DrawerHeader> - <DrawerTitle>Select Chat Model</DrawerTitle> + <DrawerHeader className="pb-0"> + <DrawerTitle>Select Model</DrawerTitle> </DrawerHeader> - {content} + <div className="flex-1 overflow-hidden">{renderContent()}</div> </DrawerContent> </Drawer> ); @@ -310,12 +1435,14 @@ export function ModelSelector({ return ( <Popover open={open} onOpenChange={handleOpenChange}> - <PopoverTrigger asChild>{trigger}</PopoverTrigger> + <PopoverTrigger asChild>{triggerButton}</PopoverTrigger> <PopoverContent + className="w-[300px] md:w-[380px] p-0 rounded-lg shadow-lg overflow-hidden select-none" align="start" - className="w-[340px] border border-popover-border bg-popover p-0 text-popover-foreground shadow-md" + sideOffset={8} + onCloseAutoFocus={(e) => e.preventDefault()} > - {content} + {renderContent()} </PopoverContent> </Popover> ); diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 1e11e95d5..46ceee694 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -14,11 +14,11 @@ const demoPlans = [ price: "0", yearlyPrice: "0", period: "", - billingText: "$5 of credit included to start", + billingText: "500 pages + $5 in premium credits included", features: [ "Self Hostable", - "$5 of credit included to start", - "One credit balance for document processing and premium AI features", + "500 pages included to start", + "$5 in premium credits for paid AI models and premium AI features", "Includes access to OpenAI text, audio and image models", "AI automations and agents: scheduled and event-triggered workflows", "Desktop app: Quick, General and Screenshot Assist plus local folder sync", @@ -38,7 +38,7 @@ const demoPlans = [ billingText: "No subscription, buy only when you need more", features: [ "Everything in Free", - "Buy credit in $1 packs — $1 buys $1 of credit, with optional auto-reload", + "Buy 1,000-page packs or $1 in premium credits at $1 each", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Connector write-back to Notion, Slack, Linear & Jira", "Priority support on Discord", @@ -84,32 +84,32 @@ interface FAQSection { const faqData: FAQSection[] = [ { - title: "Credits & Document Billing", + title: "Pages & Document Billing", items: [ { - question: "What are credits in SurfSense?", + question: 'What exactly is a "page" in SurfSense?', answer: - "Credits are a single prepaid balance shown in dollars that powers everything in SurfSense — both document processing and premium AI features. New accounts start with $5 of credit. Your balance goes down as you use the product and back up when you top up or earn more, so there's just one number to keep an eye on.", + "A page is a simple billing unit that measures how much content you add to your knowledge base. For PDFs, one page equals one real PDF page. For other document types like Word, PowerPoint, and Excel files, pages are automatically estimated based on the file. Every file uses at least 1 page.", }, { - question: "How much does document processing cost?", + question: "What are Basic and Premium processing modes?", answer: - "Document processing is billed per page out of your credit balance. For PDFs, one page equals one real PDF page; for other document types like Word, PowerPoint, and Excel files, pages are automatically estimated. Basic mode costs $0.001 per page and Premium mode costs $0.01 per page. Premium processing uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables and layouts. Every file uses at least 1 page.", + "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium processing mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. It costs 10 page credits per page and does not use your premium AI credits.", }, { question: "How does the Pay As You Go plan work?", answer: - "There's no monthly subscription. When you need more credit, simply buy $1 packs — $1 buys exactly $1 of credit. Purchased credit is added to your balance immediately so you can keep working right away. You only pay when you actually need more, and you can enable auto-reload to top up automatically.", + "There's no monthly subscription. When you need more pages, simply purchase 1,000-page packs at $1 each. Purchased pages are added to your account immediately so you can keep indexing right away. You only pay when you actually need more.", }, { - question: "What happens if I run out of credit?", + question: "What happens if I run out of pages?", answer: - "SurfSense checks your remaining credit before processing each file. If you don't have enough, the upload is paused and you'll be notified so you can buy more credit and continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.", + "SurfSense checks your remaining pages before processing each file. If you don't have enough, the upload is paused and you'll be notified. You can purchase additional page packs at any time to continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.", }, { - question: "If I delete a document, do I get my credit back?", + question: "If I delete a document, do I get my pages back?", answer: - "No. Deleting a document removes it from your knowledge base, but the credit it used is not refunded. Credit tracks your total usage over time, not how much is currently stored, so be mindful of what you index. Once credit is spent, it's spent even if you later remove the document.", + "No. Deleting a document removes it from your knowledge base, but the pages it used are not refunded. Pages track your total usage over time, not how much is currently stored. So be mindful of what you index. Once pages are spent, they're spent even if you later remove the document.", }, ], }, @@ -117,49 +117,49 @@ const faqData: FAQSection[] = [ title: "File Types & Connectors", items: [ { - question: "Which file types use credit?", + question: "Which file types count toward my page limit?", answer: - "Credit is only used for document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any credit.", + "Page limits only apply to document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any pages.", }, { - question: "How is credit consumed for documents?", + question: "How are pages consumed?", answer: - "Credit is deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). In Basic mode each page costs $0.001; in Premium mode each page costs $0.01. SurfSense checks your remaining credit before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra.", + "Pages are deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). In Basic mode, each page costs 1 page credit; in Premium mode, each page costs 10 page credits. SurfSense checks your remaining credits before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra pages.", }, { - question: "Do connectors like Slack, Notion, or Gmail use credit?", + question: "Do connectors like Slack, Notion, or Gmail use pages?", answer: - "No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use credit at all. Document-processing charges only apply to file-based connectors such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.", + "No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use pages at all. Page limits only apply to file-based connectors that need document processing, such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.", }, ], }, { - title: "Premium AI & Credit", + title: "Premium Credits", items: [ { - question: "How is credit used for premium AI?", + question: 'What are "premium credits"?', answer: - "The same credit balance pays for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.", + "Premium credits are your USD balance for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.", }, { - question: "How much credit do I get for free?", + question: "How many premium credits do I get for free?", answer: - "Every registered SurfSense account starts with $5 of credit at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included credit runs out, you can top up at any time or earn more by completing tasks.", + "Every registered SurfSense account starts with $5 in premium credits at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included premium credits run out, you can top up at any time.", }, { - question: "How does buying credit work?", + question: "How does buying premium credits work?", answer: - "Top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately, and you can buy any custom amount. Enable auto-reload to top up automatically when your balance runs low.", + "Premium credit top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately. You can buy up to $100 at a time.", }, { - question: "Is there a separate balance for documents and AI?", + question: "Are premium credits the same as page credits?", answer: - "No. SurfSense uses one unified credit balance for everything — document indexing, file-based connector processing, premium model chats, and premium AI generation features all draw from the same wallet. Premium document processing mode simply costs more per page ($0.01 vs $0.001), but it's the same credit.", + "No. Page credits pay for document indexing and file-based connector processing. Premium credits pay for paid AI usage, such as premium model chats and premium AI generation features. Premium document processing mode sounds similar, but it consumes page credits, not premium credits.", }, { - question: "What happens if I run out of credit?", + question: "What happens if I run out of premium credits?", answer: - "When your credit balance runs low, you'll see a warning. Once you run out, paid model requests, premium AI features, and document processing are paused until you top up. You can still use non-premium models and features that do not consume credit.", + "When your premium credit balance runs low, you'll see a warning. Once you run out, paid model requests and premium AI features are paused until you top up. You can still use non-premium models and features that do not consume premium credits.", }, ], }, @@ -174,7 +174,7 @@ const faqData: FAQSection[] = [ { question: "Do automations and agents cost extra?", answer: - "No. There is no separate subscription or add-on fee for automations. Agents draw from the same unified credit balance as the rest of SurfSense. Indexing documents and premium AI model usage during a workflow both consume credit at provider cost. If a workflow only uses free models and indexes no documents, it does not touch your credit.", + "No. There is no separate subscription or add-on fee for automations. Agents use the same page credits and premium credits as the rest of SurfSense. Indexing documents consumes page credits, and premium AI model usage during a workflow consumes premium credits at provider cost. If a workflow only uses free models, it does not touch your premium credits.", }, { question: "How do event-triggered automations work?", @@ -192,9 +192,9 @@ const faqData: FAQSection[] = [ title: "Self-Hosting", items: [ { - question: "Can I self-host SurfSense with unlimited usage?", + question: "Can I self-host SurfSense with unlimited pages and credit?", answer: - "Yes! When self-hosting, you have full control over billing. The default self-hosted setup leaves document-processing credit billing off and gives you effectively unlimited credit, so you can index as much data and use as many AI queries as your infrastructure supports.", + "Yes! When self-hosting, you have full control over your page and premium credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credits, so you can index as much data and use as many AI queries as your infrastructure supports.", }, ], }, @@ -286,8 +286,8 @@ function PricingFAQ() { Frequently Asked Questions </h2> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> - Everything you need to know about SurfSense credits and billing. Can't find what you - need? Reach out at{" "} + Everything you need to know about SurfSense pages, premium credits, and billing. + Can't find what you need? Reach out at{" "} <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> rohan@surfsense.com </a> @@ -372,7 +372,7 @@ function PricingBasic() { <Pricing plans={demoPlans} title="SurfSense Pricing" - description="Start free with $5 of credit. Run AI automations and agents, and pay as you go." + description="Start free with 500 pages & $5 in premium credits. Run AI automations and agents, and pay as you go." /> <PricingFAQ /> </> diff --git a/surfsense_web/components/providers/ZeroProvider.tsx b/surfsense_web/components/providers/ZeroProvider.tsx index 35d51311a..5bb43db99 100644 --- a/surfsense_web/components/providers/ZeroProvider.tsx +++ b/surfsense_web/components/providers/ZeroProvider.tsx @@ -12,15 +12,7 @@ import { getBearerToken, handleUnauthorized, refreshAccessToken } from "@/lib/au import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; -const configuredCacheURL = process.env.NEXT_PUBLIC_ZERO_CACHE_URL; - -function getCacheURL() { - if (configuredCacheURL) return configuredCacheURL; - if (typeof window !== "undefined") { - return `${window.location.origin}/zero`; - } - return "http://localhost:4848"; -} +const cacheURL = process.env.NEXT_PUBLIC_ZERO_CACHE_URL || "http://localhost:4848"; function ZeroAuthSync() { const zero = useZero(); @@ -50,7 +42,6 @@ function ZeroAuthSync() { export function ZeroProvider({ children }: { children: React.ReactNode }) { const { data: user } = useAtomValue(currentUserAtom); - const cacheURL = useMemo(() => getCacheURL(), []); const userId = user?.id; const hasUser = !!userId; @@ -74,7 +65,7 @@ export function ZeroProvider({ children }: { children: React.ReactNode }) { cacheURL, auth, }), - [userID, context, cacheURL, auth] + [userID, context, auth] ); return ( diff --git a/surfsense_web/components/providers/runtime-config.server.tsx b/surfsense_web/components/providers/runtime-config.server.tsx deleted file mode 100644 index c515820c2..000000000 --- a/surfsense_web/components/providers/runtime-config.server.tsx +++ /dev/null @@ -1,19 +0,0 @@ -import { connection } from "next/server"; -import { RuntimeConfigProvider } from "@/components/providers/runtime-config"; -import { - BUILD_TIME_AUTH_TYPE, - BUILD_TIME_DEPLOYMENT_MODE, - BUILD_TIME_ETL_SERVICE, -} from "@/lib/env-config"; - -export async function RuntimeConfig({ children }: { children: React.ReactNode }) { - await connection(); - - const value = { - authType: process.env.AUTH_TYPE ?? BUILD_TIME_AUTH_TYPE, - etlService: process.env.ETL_SERVICE ?? BUILD_TIME_ETL_SERVICE, - deploymentMode: process.env.DEPLOYMENT_MODE ?? BUILD_TIME_DEPLOYMENT_MODE, - }; - - return <RuntimeConfigProvider value={value}>{children}</RuntimeConfigProvider>; -} diff --git a/surfsense_web/components/providers/runtime-config.tsx b/surfsense_web/components/providers/runtime-config.tsx deleted file mode 100644 index 560acd597..000000000 --- a/surfsense_web/components/providers/runtime-config.tsx +++ /dev/null @@ -1,48 +0,0 @@ -"use client"; - -import { createContext, useContext } from "react"; - -export type AuthType = "LOCAL" | "GOOGLE" | string; -export type DeploymentMode = "self-hosted" | "cloud" | string; - -export interface RuntimeConfigValue { - authType: AuthType; - etlService: string; - deploymentMode: DeploymentMode; -} - -const RuntimeConfigContext = createContext<RuntimeConfigValue | null>(null); - -export function RuntimeConfigProvider({ - value, - children, -}: { - value: RuntimeConfigValue; - children: React.ReactNode; -}) { - return <RuntimeConfigContext.Provider value={value}>{children}</RuntimeConfigContext.Provider>; -} - -export function useRuntimeConfig() { - const context = useContext(RuntimeConfigContext); - if (!context) { - throw new Error("useRuntimeConfig must be used within RuntimeConfigProvider"); - } - return context; -} - -export function useIsLocalAuth() { - return useRuntimeConfig().authType === "LOCAL"; -} - -export function useIsGoogleAuth() { - return useRuntimeConfig().authType === "GOOGLE"; -} - -export function useIsSelfHosted() { - return useRuntimeConfig().deploymentMode === "self-hosted"; -} - -export function useIsCloud() { - return useRuntimeConfig().deploymentMode === "cloud"; -} diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 083cc5e35..d35193cbe 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -17,9 +17,9 @@ import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { GenerateImageToolUI } from "@/components/tool-ui/generate-image"; +import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; import { GenerateReportToolUI } from "@/components/tool-ui/generate-report"; import { GenerateResumeToolUI } from "@/components/tool-ui/generate-resume"; -import { GeneratePodcastToolUI } from "@/components/tool-ui/podcast"; const GenerateVideoPresentationToolUI = dynamic( () => diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 53b0c9867..682235e0f 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -22,7 +22,7 @@ import { Spinner } from "@/components/ui/spinner"; import { useMediaQuery } from "@/hooks/use-media-query"; import { baseApiService } from "@/lib/apis/base-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; function ReportPanelSkeleton() { return ( @@ -245,7 +245,7 @@ export function ReportPanelContent({ URL.revokeObjectURL(url); } else { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/reports/${activeReportId}/export`, { format }), + `${BACKEND_URL}/api/v1/reports/${activeReportId}/export?format=${format}`, { method: "GET" } ); @@ -278,7 +278,7 @@ export function ReportPanelContent({ setSaving(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/reports/${activeReportId}/content`), + `${BACKEND_URL}/api/v1/reports/${activeReportId}/content`, { method: "PUT", headers: { "Content-Type": "application/json" }, @@ -506,11 +506,7 @@ export function ReportPanelContent({ </div> ) : reportContent.content_type === "typst" ? ( <PdfViewer - pdfUrl={buildBackendUrl( - shareToken - ? `/api/v1/public/${shareToken}/reports/${activeReportId}/preview` - : `/api/v1/reports/${activeReportId}/preview` - )} + pdfUrl={`${BACKEND_URL}${shareToken ? `/api/v1/public/${shareToken}/reports/${activeReportId}/preview` : `/api/v1/reports/${activeReportId}/preview`}`} isPublic={isPublic} toolbarActions={ <> diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx new file mode 100644 index 000000000..507a263e0 --- /dev/null +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -0,0 +1,423 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { NewLLMConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface AgentModelManagerProps { + searchSpaceId: number; +} + +function getInitials(name: string): string { + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + // Mutations + const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( + deleteNewLLMConfigMutationAtom + ); + + // Queries + const { + data: configs, + isFetching: isLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs = [] } = useAtomValue(globalNewLLMConfigsAtom); + + // Members for user resolution + const { data: members } = useAtomValue(membersAtom); + const memberMap = useMemo(() => { + const map = new Map<string, { name: string; email?: string; avatarUrl?: string }>(); + if (members) { + for (const m of members) { + map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + // Permissions + const { data: access } = useAtomValue(myAccessAtom); + const canCreate = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:create") ?? false)); + const canUpdate = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:update") ?? false)); + const canDelete = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:delete") ?? false)); + const isReadOnly = !canCreate && !canUpdate && !canDelete; + + // Local state + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState<NewLLMConfig | null>(null); + const [configToDelete, setConfigToDelete] = useState<NewLLMConfig | null>(null); + + const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation state + } + }; + + const openEditDialog = (config: NewLLMConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + + return ( + <div className="space-y-5 md:space-y-6"> + {/* Header actions */} + <div className="flex items-center justify-between"> + <Button + variant="secondary" + size="sm" + onClick={() => refreshConfigs()} + disabled={isLoading} + className="gap-2" + > + <RefreshCw className={cn("h-3.5 w-3.5", isLoading && "animate-spin")} /> + Refresh + </Button> + {canCreate && ( + <Button + variant="outline" + onClick={openNewDialog} + className="gap-2 border-transparent bg-white text-[#1f1f1f] font-medium hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-transparent dark:bg-white dark:text-[#1f1f1f]" + > + Add Model + </Button> + )} + </div> + + {/* Fetch Error Alert */} + {fetchError && ( + <div> + <Alert variant="destructive" className="py-3 md:py-4"> + <AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> + <AlertDescription className="text-xs md:text-sm"> + {fetchError?.message ?? "Failed to load configurations"} + </AlertDescription> + </Alert> + </div> + )} + + {/* Read-only / Limited permissions notice */} + {access && !isLoading && isReadOnly && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You have <span className="font-medium">read-only</span> access to LLM + configurations. Contact a space owner to request additional permissions. + </p> + </AlertDescription> + </Alert> + </div> + )} + {access && !isLoading && !isReadOnly && (!canCreate || !canUpdate || !canDelete) && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You can{" "} + {[canCreate && "create", canUpdate && "edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + configurations + {!canDelete && ", but cannot delete them"}. + </p> + </AlertDescription> + </Alert> + </div> + )} + + {/* Global Configs Info */} + {(isLoading || globalConfigs.length > 0) && ( + <Alert> + <Info /> + <AlertDescription> + {isLoading ? ( + <div className="flex min-h-[1.625em] items-center"> + <Skeleton className="h-4 w-60 bg-accent-foreground/15" /> + </div> + ) : ( + <p> + <span className="font-medium"> + {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} + </span>{" "} + available from your administrator. + </p> + )} + </AlertDescription> + </Alert> + )} + + {/* Loading Skeleton */} + {isLoading && ( + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + <Card key={key} className="border-accent bg-accent/20"> + <CardContent className="p-4 flex flex-col gap-3 min-h-32"> + <Skeleton className="h-4 w-32 md:w-40 bg-accent" /> + <Skeleton className="h-3 w-full bg-accent" /> + <Skeleton className="h-3 w-24 md:w-28 bg-accent mt-auto" /> + </CardContent> + </Card> + ))} + </div> + )} + + {/* Configurations List */} + {!isLoading && ( + <div className="space-y-4"> + {configs?.length === 0 ? ( + <div> + <Card className="border-0 bg-transparent shadow-none"> + <CardContent className="flex flex-col items-center justify-center py-10 md:py-16 text-center"> + <h3 className="text-sm md:text-base font-semibold mb-2">No Models Yet</h3> + <p className="text-[11px] md:text-xs text-muted-foreground max-w-sm mb-4"> + {canCreate + ? "Add your first model to power chat, reports, and other agent capabilities" + : "No models have been added to this space yet. Contact a space owner to add one"} + </p> + </CardContent> + </Card> + </div> + ) : ( + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {configs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( + <div key={config.id}> + <Card className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"> + <CardContent className="p-4 flex flex-col gap-3 h-full"> + {/* Header: Icon + Name + Actions */} + <div className="flex items-center justify-between gap-2"> + <div className="flex items-center gap-2.5 min-w-0 flex-1"> + <div className="shrink-0"> + {getProviderIcon(config.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {config.name} + </h4> + {config.description && ( + <p className="text-[11px] text-muted-foreground/70 truncate mt-0.5"> + {config.description} + </p> + )} + </div> + </div> + {(canUpdate || canDelete) && ( + <div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150"> + {canUpdate && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => openEditDialog(config)} + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-accent-foreground" + > + <Pencil className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Edit</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + {canDelete && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => setConfigToDelete(config)} + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" + > + <Trash2 className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Delete</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + </div> + )} + </div> + + {/* Feature badges */} + <div className="flex items-center gap-1.5 flex-wrap"> + {config.citations_enabled && ( + <Badge + variant="secondary" + className="text-[10px] px-1.5 py-0.5 border-0 text-muted-foreground bg-muted" + > + Citations + </Badge> + )} + {!config.use_default_system_instructions && + config.system_instructions && ( + <Badge + variant="secondary" + className="text-[10px] px-1.5 py-0.5 border-0 text-muted-foreground bg-muted" + > + <FileText className="h-2.5 w-2.5 mr-1" /> + Custom + </Badge> + )} + </div> + + {/* Footer: Date + Creator */} + <div className="mt-auto space-y-2"> + <Separator className="bg-accent" /> + <div className="flex items-center"> + <span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap"> + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + </span> + {member && ( + <> + <Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" /> + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <div className="min-w-0 flex items-center gap-1.5 cursor-default"> + <Avatar className="size-4.5 shrink-0"> + {member.avatarUrl && ( + <AvatarImage src={member.avatarUrl} alt={member.name} /> + )} + <AvatarFallback className="text-[9px]"> + {getInitials(member.name)} + </AvatarFallback> + </Avatar> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {member.name} + </span> + </div> + </TooltipTrigger> + <TooltipContent side="bottom"> + {member.email || member.name} + </TooltipContent> + </Tooltip> + </TooltipProvider> + </> + )} + </div> + </div> + </CardContent> + </Card> + </div> + ); + })} + </div> + )} + </div> + )} + + {/* Add/Edit Configuration Dialog */} + <ModelConfigDialog + open={isDialogOpen} + onOpenChange={(open) => { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + {/* Delete Confirmation Dialog */} + <AlertDialog + open={!!configToDelete} + onOpenChange={(open) => !open && setConfigToDelete(null)} + > + <AlertDialogContent className="select-none"> + <AlertDialogHeader> + <AlertDialogTitle>Delete Model</AlertDialogTitle> + <AlertDialogDescription> + Are you sure you want to delete{" "} + <span className="font-semibold text-foreground">{configToDelete?.name}</span>? This + action cannot be undone. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={handleDelete} + disabled={isDeleting} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + {isDeleting ? ( + <> + <Spinner size="sm" className="mr-2" /> + Deleting + </> + ) : ( + "Delete" + )} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </div> + ); +} diff --git a/surfsense_web/components/settings/auto-reload-settings.tsx b/surfsense_web/components/settings/auto-reload-settings.tsx deleted file mode 100644 index fbb7cbfb9..000000000 --- a/surfsense_web/components/settings/auto-reload-settings.tsx +++ /dev/null @@ -1,276 +0,0 @@ -"use client"; - -import { useQuery as useZeroQuery } from "@rocicorp/zero/react"; -import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { AlertTriangle, CreditCard, RefreshCw } from "lucide-react"; -import { useParams, usePathname, useRouter, useSearchParams } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Spinner } from "@/components/ui/spinner"; -import { Switch } from "@/components/ui/switch"; -import { stripeApiService } from "@/lib/apis/stripe-api.service"; -import { AppError } from "@/lib/error"; -import { queries } from "@/zero/queries"; - -const microsToDollars = (micros: number | null | undefined): string => { - if (micros == null) return ""; - return (micros / 1_000_000).toString(); -}; - -const dollarsToMicros = (value: string): number | null => { - const trimmed = value.trim(); - if (trimmed === "") return null; - const dollars = Number(trimmed); - if (!Number.isFinite(dollars) || dollars < 0) return null; - return Math.round(dollars * 1_000_000); -}; - -const formatUsd = (micros: number) => `$${(Math.max(0, micros) / 1_000_000).toFixed(2)}`; - -export function AutoReloadSettings() { - const params = useParams(); - const router = useRouter(); - const pathname = usePathname(); - const searchParams = useSearchParams(); - const queryClient = useQueryClient(); - const searchSpaceId = Number(params?.search_space_id); - - const [enabled, setEnabled] = useState(false); - const [thresholdInput, setThresholdInput] = useState(""); - const [amountInput, setAmountInput] = useState(""); - const seededRef = useRef(false); - - const [me] = useZeroQuery(queries.user.me({})); - const balanceMicros = me?.creditMicrosBalance ?? 0; - - const { data: settings, isLoading } = useQuery({ - queryKey: ["auto-reload-settings"], - queryFn: () => stripeApiService.getAutoReloadSettings(), - }); - - // Seed the form once from the server, then let the user own the inputs. - useEffect(() => { - if (settings && !seededRef.current) { - seededRef.current = true; - setEnabled(settings.enabled); - setThresholdInput(microsToDollars(settings.threshold_micros)); - setAmountInput(microsToDollars(settings.amount_micros)); - } - }, [settings]); - - // Surface the result of the Stripe card-setup redirect. - useEffect(() => { - const setupResult = searchParams.get("auto_reload_setup"); - if (!setupResult) return; - if (setupResult === "success") { - toast.success("Card saved. You can now enable auto-reload."); - queryClient.invalidateQueries({ queryKey: ["auto-reload-settings"] }); - } else if (setupResult === "cancel") { - toast.info("Card setup canceled."); - } - // Strip the query param so refreshes don't re-toast. - router.replace(pathname); - }, [searchParams, router, pathname, queryClient]); - - const setupMutation = useMutation({ - mutationFn: () => - stripeApiService.createAutoReloadSetupSession({ search_space_id: searchSpaceId }), - onSuccess: (response) => { - window.location.assign(response.checkout_url); - }, - onError: () => { - toast.error("Couldn't start card setup. Please try again."); - }, - }); - - const saveMutation = useMutation({ - mutationFn: stripeApiService.updateAutoReloadSettings, - onSuccess: (updated) => { - queryClient.setQueryData(["auto-reload-settings"], updated); - toast.success(updated.enabled ? "Auto-reload is on." : "Auto-reload settings saved."); - }, - onError: (error) => { - if (error instanceof AppError && error.message) { - toast.error(error.message); - return; - } - toast.error("Couldn't save auto-reload settings. Please try again."); - }, - }); - - // Render nothing while loading (avoids a spinner flash on pages where the - // feature flag turns out to be off) and when auto-reload is disabled - // server-side. - if (isLoading || !settings || !settings.feature_enabled) { - return null; - } - - const minAmountDollars = (settings.min_amount_micros / 1_000_000).toFixed(2); - const hasCard = settings.has_payment_method; - - const handleSave = () => { - if (!enabled) { - saveMutation.mutate({ - enabled: false, - threshold_micros: dollarsToMicros(thresholdInput), - amount_micros: dollarsToMicros(amountInput), - }); - return; - } - - const thresholdMicros = dollarsToMicros(thresholdInput); - const amountMicros = dollarsToMicros(amountInput); - - if (!thresholdMicros || thresholdMicros <= 0) { - toast.error("Enter a low-balance threshold greater than $0."); - return; - } - if (amountMicros == null || amountMicros < settings.min_amount_micros) { - toast.error(`Reload amount must be at least $${minAmountDollars}.`); - return; - } - - saveMutation.mutate({ - enabled: true, - threshold_micros: thresholdMicros, - amount_micros: amountMicros, - }); - }; - - return ( - <Card> - <CardHeader> - <CardTitle className="flex items-center gap-2 text-base"> - <RefreshCw className="h-4 w-4 text-amber-500" /> - Auto-reload - </CardTitle> - <CardDescription> - Automatically top up your credit balance when it drops below a threshold, using a saved - card. Current balance:{" "} - <span className="font-medium text-foreground">{formatUsd(balanceMicros)}</span>. - </CardDescription> - </CardHeader> - <CardContent className="space-y-5"> - {settings.failed_at && ( - <Alert variant="destructive"> - <AlertTriangle className="h-4 w-4" /> - <AlertTitle>Last auto-reload failed</AlertTitle> - <AlertDescription> - Your saved card was declined and auto-reload was turned off. Update your card and - re-enable it below to keep topping up automatically. - </AlertDescription> - </Alert> - )} - - {!hasCard ? ( - <div className="flex flex-col items-start gap-3 rounded-lg border bg-muted/20 p-4"> - <div className="flex items-center gap-2 text-sm"> - <CreditCard className="h-4 w-4 text-muted-foreground" /> - <span>Add a card to enable automatic top-ups.</span> - </div> - <Button onClick={() => setupMutation.mutate()} disabled={setupMutation.isPending}> - {setupMutation.isPending ? ( - <> - <Spinner size="xs" /> - Redirecting - </> - ) : ( - "Add a card" - )} - </Button> - </div> - ) : ( - <> - <div className="flex items-center justify-between gap-4"> - <div className="space-y-0.5"> - <Label htmlFor="auto-reload-toggle" className="text-sm font-medium"> - Enable auto-reload - </Label> - <p className="text-xs text-muted-foreground"> - Charge your saved card when the balance gets low. - </p> - </div> - <Switch id="auto-reload-toggle" checked={enabled} onCheckedChange={setEnabled} /> - </div> - - <div className="grid gap-4 sm:grid-cols-2"> - <div className="space-y-1.5"> - <Label htmlFor="auto-reload-threshold" className="text-xs"> - When balance falls below - </Label> - <div className="relative"> - <span className="pointer-events-none absolute left-3 top-1/2 -translate-y-1/2 text-sm text-muted-foreground"> - $ - </span> - <Input - id="auto-reload-threshold" - type="number" - min="0" - step="1" - inputMode="decimal" - className="pl-6 tabular-nums" - value={thresholdInput} - onChange={(e) => setThresholdInput(e.target.value)} - disabled={!enabled} - placeholder="5" - /> - </div> - </div> - <div className="space-y-1.5"> - <Label htmlFor="auto-reload-amount" className="text-xs"> - Add this much credit - </Label> - <div className="relative"> - <span className="pointer-events-none absolute left-3 top-1/2 -translate-y-1/2 text-sm text-muted-foreground"> - $ - </span> - <Input - id="auto-reload-amount" - type="number" - min={minAmountDollars} - step="1" - inputMode="decimal" - className="pl-6 tabular-nums" - value={amountInput} - onChange={(e) => setAmountInput(e.target.value)} - disabled={!enabled} - placeholder="10" - /> - </div> - <p className="text-[11px] text-muted-foreground">Minimum ${minAmountDollars}.</p> - </div> - </div> - - <div className="flex items-center justify-between gap-3"> - <Button - variant="ghost" - size="sm" - className="text-muted-foreground" - onClick={() => setupMutation.mutate()} - disabled={setupMutation.isPending} - > - <CreditCard className="h-3.5 w-3.5" /> - Update card - </Button> - <Button onClick={handleSave} disabled={saveMutation.isPending}> - {saveMutation.isPending ? ( - <> - <Spinner size="xs" /> - Saving - </> - ) : ( - "Save" - )} - </Button> - </div> - </> - )} - </CardContent> - </Card> - ); -} diff --git a/surfsense_web/components/settings/buy-pages-content.tsx b/surfsense_web/components/settings/buy-pages-content.tsx new file mode 100644 index 000000000..82b8d8e2a --- /dev/null +++ b/surfsense_web/components/settings/buy-pages-content.tsx @@ -0,0 +1,148 @@ +"use client"; + +import { useMutation, useQuery } from "@tanstack/react-query"; +import { Minus, Plus } from "lucide-react"; +import { useParams } from "next/navigation"; +import { useState } from "react"; +import { toast } from "sonner"; +import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; +import { stripeApiService } from "@/lib/apis/stripe-api.service"; +import { AppError } from "@/lib/error"; +import { cn } from "@/lib/utils"; + +const PAGE_PACK_SIZE = 1000; +const PRICE_PER_PACK_USD = 1; +const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; + +export function BuyPagesContent() { + const params = useParams(); + const [quantity, setQuantity] = useState(1); + const { data: stripeStatus } = useQuery({ + queryKey: ["stripe-status"], + queryFn: () => stripeApiService.getStatus(), + }); + + const purchaseMutation = useMutation({ + mutationFn: stripeApiService.createCheckoutSession, + onSuccess: (response) => { + window.location.assign(response.checkout_url); + }, + onError: (error) => { + if (error instanceof AppError && error.message) { + toast.error(error.message); + return; + } + toast.error("Failed to start checkout. Please try again."); + }, + }); + + const searchSpaceId = Number(params.search_space_id); + const hasValidSearchSpace = Number.isFinite(searchSpaceId) && searchSpaceId > 0; + const totalPages = quantity * PAGE_PACK_SIZE; + const totalPrice = quantity * PRICE_PER_PACK_USD; + + if (stripeStatus && !stripeStatus.page_buying_enabled) { + return ( + <div className="w-full space-y-3 text-center"> + <h2 className="text-xl font-bold tracking-tight">Buy Pages</h2> + <p className="text-sm text-muted-foreground">Page purchases are temporarily unavailable.</p> + </div> + ); + } + + const handleBuyNow = () => { + if (!hasValidSearchSpace) { + toast.error("Unable to determine the current workspace for checkout."); + return; + } + purchaseMutation.mutate({ + quantity, + search_space_id: searchSpaceId, + }); + }; + + return ( + <div className="w-full space-y-5"> + <div className="text-center"> + <h2 className="text-xl font-bold tracking-tight">Buy Pages</h2> + <p className="mt-1 text-sm text-muted-foreground">$1 per 1,000 pages, pay as you go</p> + </div> + + <div className="space-y-3"> + {/* Stepper */} + <div className="flex items-center justify-center gap-3"> + <Button + type="button" + variant="ghost" + size="icon" + onClick={() => setQuantity((q) => Math.max(1, q - 1))} + disabled={quantity <= 1 || purchaseMutation.isPending} + className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" + > + <Minus className="h-3.5 w-3.5" /> + </Button> + <span className="min-w-28 text-center text-lg font-semibold tabular-nums"> + {totalPages.toLocaleString()} + </span> + <Button + type="button" + variant="ghost" + size="icon" + onClick={() => setQuantity((q) => Math.min(100, q + 1))} + disabled={quantity >= 100 || purchaseMutation.isPending} + className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" + > + <Plus className="h-3.5 w-3.5" /> + </Button> + </div> + + {/* Quick-pick presets */} + <div className="flex flex-wrap justify-center gap-1.5"> + {PRESET_MULTIPLIERS.map((m) => ( + <Button + key={m} + type="button" + variant="ghost" + onClick={() => setQuantity(m)} + disabled={purchaseMutation.isPending} + className={cn( + "h-auto rounded-md px-2.5 py-1 text-xs font-medium tabular-nums transition-colors disabled:opacity-60", + quantity === m + ? "bg-accent text-accent-foreground" + : "text-muted-foreground hover:bg-accent hover:text-accent-foreground" + )} + > + {(m * PAGE_PACK_SIZE).toLocaleString()} + </Button> + ))} + </div> + + <div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2"> + <span className="text-sm font-medium tabular-nums"> + {totalPages.toLocaleString()} pages + </span> + <span className="text-sm font-semibold tabular-nums">${totalPrice}</span> + </div> + + <Button + className="w-full" + disabled={purchaseMutation.isPending || !hasValidSearchSpace} + onClick={handleBuyNow} + > + {purchaseMutation.isPending ? ( + <> + <Spinner size="xs" /> + Redirecting + </> + ) : ( + <> + Buy {totalPages.toLocaleString()} Pages for ${totalPrice} + </> + )} + </Button> + <p className="text-center text-[11px] text-muted-foreground">Secure checkout via Stripe</p> + </div> + </div> + ); +} diff --git a/surfsense_web/components/settings/buy-credits-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx similarity index 57% rename from surfsense_web/components/settings/buy-credits-content.tsx rename to surfsense_web/components/settings/buy-tokens-content.tsx index 8cb339420..4b0605f28 100644 --- a/surfsense_web/components/settings/buy-credits-content.tsx +++ b/surfsense_web/components/settings/buy-tokens-content.tsx @@ -7,59 +7,46 @@ import { useParams } from "next/navigation"; import { useState } from "react"; import { toast } from "sonner"; import { Button } from "@/components/ui/button"; +import { Progress } from "@/components/ui/progress"; import { Spinner } from "@/components/ui/spinner"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { AppError } from "@/lib/error"; import { cn } from "@/lib/utils"; import { queries } from "@/zero/queries"; -// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the backend. -// ETL page processing and premium turns are both debited from the same wallet -// at the actual cost, so $1 of credit always buys $1 of usage at cost. +// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the +// backend. Premium turns are debited at the actual provider cost +// reported by LiteLLM, so $1 of credit always buys $1 of provider +// usage at cost. const CREDIT_PER_PACK_MICROS = 1_000_000; const PRICE_PER_PACK_USD = 1; -const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50, 100] as const; -const MIN_QUANTITY = 1; -const MAX_QUANTITY = 10_000; +const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; -const clampQuantity = (value: number) => - Math.min(MAX_QUANTITY, Math.max(MIN_QUANTITY, Math.floor(value))); - -const formatUsd = (micros: number) => { - // Clamp at $0.00 — the balance can dip slightly negative when actual cost - // exceeds the pre-charge estimate. - const dollars = Math.max(0, micros) / 1_000_000; +const formatUsd = (micros: number, options?: { compact?: boolean }) => { + const dollars = micros / 1_000_000; + if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`; if (dollars >= 100) return `$${dollars.toFixed(0)}`; if (dollars >= 1) return `$${dollars.toFixed(2)}`; if (dollars > 0) return `$${dollars.toFixed(3)}`; - return "$0.00"; + return "$0"; }; -export function BuyCreditsContent() { +export function BuyTokensContent() { const params = useParams(); const searchSpaceId = Number(params?.search_space_id); const [quantity, setQuantity] = useState(1); - // Raw text of the amount field so the user can clear it while typing; - // committed back to a clamped integer on blur. - const [amountInput, setAmountInput] = useState("1"); - - const commitQuantity = (value: number) => { - const clamped = clampQuantity(Number.isFinite(value) ? value : MIN_QUANTITY); - setQuantity(clamped); - setAmountInput(String(clamped)); - }; // Server config flag: stays on REST, not per-user. - const { data: creditStatus } = useQuery({ - queryKey: ["credit-status"], - queryFn: () => stripeApiService.getCreditStatus(), + const { data: tokenStatus } = useQuery({ + queryKey: ["token-status"], + queryFn: () => stripeApiService.getTokenStatus(), }); // Live per-user balance via Zero. const [me] = useZeroQuery(queries.user.me({})); const purchaseMutation = useMutation({ - mutationFn: stripeApiService.createCreditCheckoutSession, + mutationFn: stripeApiService.createTokenCheckoutSession, onSuccess: (response) => { window.location.assign(response.checkout_url); }, @@ -75,10 +62,10 @@ export function BuyCreditsContent() { const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS; const totalPrice = quantity * PRICE_PER_PACK_USD; - if (creditStatus && !creditStatus.credit_buying_enabled) { + if (tokenStatus && !tokenStatus.token_buying_enabled) { return ( <div className="w-full space-y-3 text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Credits</h2> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> <p className="text-sm text-muted-foreground"> Credit purchases are temporarily unavailable. </p> @@ -86,20 +73,35 @@ export function BuyCreditsContent() { ); } - const balanceMicros = me?.creditMicrosBalance ?? creditStatus?.credit_micros_balance ?? 0; + const used = me?.premiumCreditMicrosUsed ?? 0; + const limit = me?.premiumCreditMicrosLimit ?? 0; + // Mirrors the backend formula in stripe_routes.py (max(0, limit - used)). + const remaining = Math.max(0, limit - used); + const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; return ( <div className="w-full space-y-5"> <div className="text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Credits</h2> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> + <p className="mt-1 text-sm text-muted-foreground"> + $1 buys $1 of credit, billed at provider cost + </p> </div> - <div className="rounded-lg border bg-muted/20 p-3"> - <div className="flex items-center justify-between text-sm"> - <span className="text-muted-foreground">Current balance</span> - <span className="font-semibold tabular-nums">{formatUsd(balanceMicros)}</span> + {me && ( + <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {formatUsd(used)} / {formatUsd(limit)} of credit + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5 [&>div]:bg-purple-500" /> + <p className="text-[11px] text-muted-foreground"> + {formatUsd(remaining)} of credit remaining + </p> </div> - </div> + )} <div className="space-y-3"> <div className="flex items-center justify-center gap-3"> @@ -107,39 +109,21 @@ export function BuyCreditsContent() { type="button" variant="ghost" size="icon" - onClick={() => commitQuantity(quantity - 1)} - disabled={quantity <= MIN_QUANTITY || purchaseMutation.isPending} + onClick={() => setQuantity((q) => Math.max(1, q - 1))} + disabled={quantity <= 1 || purchaseMutation.isPending} className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" > <Minus className="h-3.5 w-3.5" /> </Button> - <div className="flex items-baseline gap-1.5"> - <span className="text-lg font-semibold">$</span> - <input - type="text" - inputMode="numeric" - value={amountInput} - onChange={(e) => { - const raw = e.target.value.replace(/[^0-9]/g, ""); - setAmountInput(raw); - const parsed = Number.parseInt(raw, 10); - if (Number.isFinite(parsed)) { - setQuantity(clampQuantity(parsed)); - } - }} - onBlur={() => commitQuantity(Number.parseInt(amountInput, 10))} - disabled={purchaseMutation.isPending} - aria-label="Credit amount in US dollars" - className="w-20 rounded-md border bg-transparent px-2 py-1 text-center text-lg font-semibold tabular-nums outline-none focus:ring-2 focus:ring-ring disabled:opacity-60" - /> - <span className="text-sm text-muted-foreground">of credit</span> - </div> + <span className="min-w-32 text-center text-lg font-semibold tabular-nums"> + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit + </span> <Button type="button" variant="ghost" size="icon" - onClick={() => commitQuantity(quantity + 1)} - disabled={quantity >= MAX_QUANTITY || purchaseMutation.isPending} + onClick={() => setQuantity((q) => Math.min(100, q + 1))} + disabled={quantity >= 100 || purchaseMutation.isPending} className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" > <Plus className="h-3.5 w-3.5" /> @@ -152,7 +136,7 @@ export function BuyCreditsContent() { key={m} type="button" variant="ghost" - onClick={() => commitQuantity(m)} + onClick={() => setQuantity(m)} disabled={purchaseMutation.isPending} className={cn( "h-auto rounded-md px-2.5 py-1 text-xs font-medium tabular-nums transition-colors disabled:opacity-60", diff --git a/surfsense_web/components/settings/general-settings-manager.tsx b/surfsense_web/components/settings/general-settings-manager.tsx index 68ff21f07..a308acfad 100644 --- a/surfsense_web/components/settings/general-settings-manager.tsx +++ b/surfsense_web/components/settings/general-settings-manager.tsx @@ -12,7 +12,7 @@ import { Label } from "@/components/ui/label"; import { Skeleton } from "@/components/ui/skeleton"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import { Spinner } from "../ui/spinner"; @@ -49,7 +49,7 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager setIsExporting(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`, { method: "GET" } ); if (!response.ok) { diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx new file mode 100644 index 000000000..494f7aae9 --- /dev/null +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -0,0 +1,489 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface ImageModelManagerProps { + searchSpaceId: number; +} + +function getInitials(name: string): string { + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + + const { + mutateAsync: deleteConfig, + isPending: isDeleting, + error: deleteError, + } = useAtomValue(deleteImageGenConfigMutationAtom); + + const { + data: userConfigs, + isFetching: configsLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(imageGenConfigsAtom); + const { data: globalConfigs = [], isFetching: globalLoading } = + useAtomValue(globalImageGenConfigsAtom); + + const { data: members } = useAtomValue(membersAtom); + const memberMap = useMemo(() => { + const map = new Map<string, { name: string; email?: string; avatarUrl?: string }>(); + if (members) { + for (const m of members) { + map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + const { data: access } = useAtomValue(myAccessAtom); + const canCreate = + !!access && + (access.is_owner || (access.permissions?.includes("image_generations:create") ?? false)); + const canDelete = + !!access && + (access.is_owner || (access.permissions?.includes("image_generations:delete") ?? false)); + const canUpdate = canCreate; + const isReadOnly = !canCreate && !canDelete; + + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState<ImageGenerationConfig | null>(null); + const [configToDelete, setConfigToDelete] = useState<ImageGenerationConfig | null>(null); + + const isLoading = configsLoading || globalLoading; + const errors = [deleteError, fetchError].filter(Boolean) as Error[]; + + const openEditDialog = (config: ImageGenerationConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + + const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation + } + }; + + return ( + <div className="space-y-4 md:space-y-6"> + {/* Header actions */} + <div className="flex items-center justify-between"> + <Button + variant="secondary" + size="sm" + onClick={() => refreshConfigs()} + disabled={isLoading} + className="gap-2" + > + <RefreshCw className={cn("h-3.5 w-3.5", configsLoading && "animate-spin")} /> + Refresh + </Button> + {canCreate && ( + <Button + variant="outline" + onClick={openNewDialog} + className="gap-2 border-transparent bg-white text-[#1f1f1f] font-medium hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-transparent dark:bg-white dark:text-[#1f1f1f]" + > + Add Image Model + </Button> + )} + </div> + + {/* Errors */} + {errors.map((err) => ( + <div key={err?.message}> + <Alert variant="destructive" className="py-3"> + <AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> + <AlertDescription className="text-xs md:text-sm">{err?.message}</AlertDescription> + </Alert> + </div> + ))} + + {/* Read-only / Limited permissions notice */} + {access && !isLoading && isReadOnly && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You have <span className="font-medium">read-only</span> access to image generation + configurations. Contact a space owner to request additional permissions. + </p> + </AlertDescription> + </Alert> + </div> + )} + {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You can{" "} + {[canCreate && "create and edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + image model configurations + {!canDelete && ", but cannot delete them"}. + </p> + </AlertDescription> + </Alert> + </div> + )} + + {/* Global info */} + {(isLoading || + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( + <Alert> + <Info /> + <AlertDescription> + {isLoading ? ( + <div className="flex min-h-[1.625em] items-center"> + <Skeleton className="h-4 w-60 bg-accent-foreground/15" /> + </div> + ) : ( + <p> + <span className="font-medium"> + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} + global image{" "} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === + 1 + ? "model" + : "models"} + </span>{" "} + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} + </p> + )} + </AlertDescription> + </Alert> + )} + + {/* Global Image Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + <div className="space-y-3"> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + <Card + key={cfg.id} + className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full" + > + <CardContent className="p-4 flex flex-col gap-3 h-full"> + <div className="flex items-center gap-2 min-w-0"> + <div className="shrink-0"> + {getProviderIcon(cfg.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1 flex items-center gap-1.5"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {cfg.name} + </h4> + {isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0" + > + Free + </Badge> + )} + </div> + </div> + {cfg.description && ( + <p className="text-[11px] text-muted-foreground/70 line-clamp-2"> + {cfg.description} + </p> + )} + <div className="mt-auto space-y-2"> + <Separator className="bg-accent" /> + <div className="flex items-center"> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {cfg.model_name} + </span> + </div> + </div> + </CardContent> + </Card> + ); + })} + </div> + </div> + )} + + {/* Loading Skeleton */} + {isLoading && ( + <div className="space-y-4 md:space-y-6"> + <div className="space-y-4"> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + <Card key={key} className="border-accent bg-accent/20"> + <CardContent className="p-4 flex flex-col gap-3 min-h-32"> + <Skeleton className="h-4 w-32 md:w-40 bg-accent" /> + <Skeleton className="h-3 w-full bg-accent" /> + <Skeleton className="h-3 w-24 md:w-28 bg-accent mt-auto" /> + </CardContent> + </Card> + ))} + </div> + </div> + </div> + )} + + {/* User Configs */} + {!isLoading && ( + <div className="space-y-4 md:space-y-6"> + {(userConfigs?.length ?? 0) === 0 ? ( + <Card className="border-0 bg-transparent shadow-none"> + <CardContent className="flex flex-col items-center justify-center py-10 md:py-16 text-center"> + <h3 className="text-sm md:text-base font-semibold mb-2">No Image Models Yet</h3> + <p className="text-[11px] md:text-xs text-muted-foreground max-w-sm mb-4"> + {canCreate + ? "Add your own image generation model (DALL-E 3, GPT Image 1, etc.)" + : "No image models have been added to this space yet. Contact a space owner to add one."} + </p> + </CardContent> + </Card> + ) : ( + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {userConfigs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( + <div key={config.id}> + <Card className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"> + <CardContent className="p-4 flex flex-col gap-3 h-full"> + {/* Header: Icon + Name + Actions */} + <div className="flex items-center justify-between gap-2"> + <div className="flex items-center gap-2.5 min-w-0 flex-1"> + <div className="shrink-0"> + {getProviderIcon(config.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {config.name} + </h4> + {config.description && ( + <p className="text-[11px] text-muted-foreground/70 truncate mt-0.5"> + {config.description} + </p> + )} + </div> + </div> + {(canUpdate || canDelete) && ( + <div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150"> + {canUpdate && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => openEditDialog(config)} + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-accent-foreground" + > + <Pencil className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Edit</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + {canDelete && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => setConfigToDelete(config)} + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" + > + <Trash2 className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Delete</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + </div> + )} + </div> + + {/* Footer: Date + Creator */} + <div className="mt-auto space-y-2"> + <Separator className="bg-accent" /> + <div className="flex items-center"> + <span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap"> + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + </span> + {member && ( + <> + <Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" /> + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <div className="min-w-0 flex items-center gap-1.5 cursor-default"> + <Avatar className="size-4.5 shrink-0"> + {member.avatarUrl && ( + <AvatarImage src={member.avatarUrl} alt={member.name} /> + )} + <AvatarFallback className="text-[9px]"> + {getInitials(member.name)} + </AvatarFallback> + </Avatar> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {member.name} + </span> + </div> + </TooltipTrigger> + <TooltipContent side="bottom"> + {member.email || member.name} + </TooltipContent> + </Tooltip> + </TooltipProvider> + </> + )} + </div> + </div> + </CardContent> + </Card> + </div> + ); + })} + </div> + )} + </div> + )} + + {/* Create/Edit Dialog — shared component */} + <ImageConfigDialog + open={isDialogOpen} + onOpenChange={(open) => { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + {/* Delete Confirmation */} + <AlertDialog + open={!!configToDelete} + onOpenChange={(open) => !open && setConfigToDelete(null)} + > + <AlertDialogContent className="select-none"> + <AlertDialogHeader> + <AlertDialogTitle>Delete Image Model</AlertDialogTitle> + <AlertDialogDescription> + Are you sure you want to delete{" "} + <span className="font-semibold text-foreground">{configToDelete?.name}</span>? + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={handleDelete} + disabled={isDeleting} + className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + <span className={isDeleting ? "opacity-0" : ""}>Delete</span> + {isDeleting && <Spinner size="sm" className="absolute" />} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </div> + ); +} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx new file mode 100644 index 000000000..c32e79a8e --- /dev/null +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -0,0 +1,443 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { + AlertCircle, + Bot, + CircleCheck, + CircleDashed, + FileText, + ImageIcon, + RefreshCw, + ScanEye, +} from "lucide-react"; +import { useCallback, useEffect, useState } from "react"; +import { toast } from "sonner"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { cn } from "@/lib/utils"; + +const ROLE_DESCRIPTIONS = { + agent: { + icon: Bot, + title: "Agent LLM", + description: "Primary LLM for chat interactions and agent operations", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "agent_llm_id" as const, + configType: "llm" as const, + }, + image_generation: { + icon: ImageIcon, + title: "Image Generation Model", + description: "Model used for AI image generation (DALL-E, GPT Image, etc.)", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "image_generation_config_id" as const, + configType: "image" as const, + }, + vision: { + icon: ScanEye, + title: "Vision LLM", + description: "Vision-capable model for screenshot analysis and context extraction", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "vision_llm_config_id" as const, + configType: "vision" as const, + }, +}; + +interface LLMRoleManagerProps { + searchSpaceId: number; +} + +export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { + // LLM configs + const { + data: newLLMConfigs = [], + isFetching: configsLoading, + error: configsError, + refetch: refreshConfigs, + } = useAtomValue(newLLMConfigsAtom); + const { + data: globalConfigs = [], + isFetching: globalConfigsLoading, + error: globalConfigsError, + } = useAtomValue(globalNewLLMConfigsAtom); + + // Image gen configs + const { + data: userImageConfigs = [], + isFetching: imageConfigsLoading, + error: imageConfigsError, + } = useAtomValue(imageGenConfigsAtom); + const { + data: globalImageConfigs = [], + isFetching: globalImageConfigsLoading, + error: globalImageConfigsError, + } = useAtomValue(globalImageGenConfigsAtom); + + // Vision LLM configs + const { + data: userVisionConfigs = [], + isFetching: visionConfigsLoading, + error: visionConfigsError, + } = useAtomValue(visionLLMConfigsAtom); + const { + data: globalVisionConfigs = [], + isFetching: globalVisionConfigsLoading, + error: globalVisionConfigsError, + } = useAtomValue(globalVisionLLMConfigsAtom); + + // Preferences + const { + data: preferences = {}, + isFetching: preferencesLoading, + error: preferencesError, + } = useAtomValue(llmPreferencesAtom); + + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const [assignments, setAssignments] = useState<Record<string, number | null>>(() => ({ + agent_llm_id: preferences.agent_llm_id ?? null, + image_generation_config_id: preferences.image_generation_config_id ?? null, + vision_llm_config_id: preferences.vision_llm_config_id ?? null, + })); + + // Sync local state when preferences load/change. Without this, the selects + // stay on their initial (often empty) value while the query is in flight, + // so a saved assignment — including Auto mode (id 0) — never appears. + useEffect(() => { + setAssignments({ + agent_llm_id: preferences.agent_llm_id ?? null, + image_generation_config_id: preferences.image_generation_config_id ?? null, + vision_llm_config_id: preferences.vision_llm_config_id ?? null, + }); + }, [ + preferences.agent_llm_id, + preferences.image_generation_config_id, + preferences.vision_llm_config_id, + ]); + + const [savingRole, setSavingRole] = useState<string | null>(null); + + const handleRoleAssignment = useCallback( + async (prefKey: string, configId: string) => { + // "unassigned" clears the role (null). Every other option — including + // Auto mode, whose config id is 0 — must be sent as-is. Using a falsy + // check here (e.g. `value || undefined`) would drop id 0 and silently + // fail to persist Auto mode. + const value = configId === "unassigned" ? null : Number(configId); + + setAssignments((prev) => ({ ...prev, [prefKey]: value })); + setSavingRole(prefKey); + + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { [prefKey]: value }, + }); + toast.success("Role assignment updated"); + } finally { + setSavingRole(null); + } + }, + [updatePreferences, searchSpaceId] + ); + + // Combine global and custom LLM configs + const allLLMConfigs = [ + ...globalConfigs.map((config) => ({ ...config, is_global: true })), + ...newLLMConfigs.filter((config) => config.id && config.id.toString().trim() !== ""), + ]; + + // Combine global and custom image gen configs + const allImageConfigs = [ + ...globalImageConfigs.map((config) => ({ ...config, is_global: true })), + ...(userImageConfigs ?? []).filter((config) => config.id && config.id.toString().trim() !== ""), + ]; + + // Combine global and custom vision LLM configs + const allVisionConfigs = [ + ...globalVisionConfigs.map((config) => ({ ...config, is_global: true })), + ...(userVisionConfigs ?? []).filter( + (config) => config.id && config.id.toString().trim() !== "" + ), + ]; + + const isLoading = + configsLoading || + preferencesLoading || + globalConfigsLoading || + imageConfigsLoading || + globalImageConfigsLoading || + visionConfigsLoading || + globalVisionConfigsLoading; + const hasError = + configsError || + preferencesError || + globalConfigsError || + imageConfigsError || + globalImageConfigsError || + visionConfigsError || + globalVisionConfigsError; + const hasAnyConfigs = allLLMConfigs.length > 0 || allImageConfigs.length > 0; + + return ( + <div className="space-y-5 md:space-y-6"> + {/* Header actions */} + <div className="flex items-center justify-start"> + <Button + variant="secondary" + size="sm" + onClick={() => refreshConfigs()} + disabled={isLoading} + className="gap-2" + > + <RefreshCw className={cn("h-3.5 w-3.5", isLoading && "animate-spin")} /> + Refresh + </Button> + </div> + + {/* Error Alert */} + {hasError && ( + <div> + <Alert variant="destructive" className="py-3 md:py-4"> + <AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> + <AlertDescription className="text-xs md:text-sm"> + {(configsError?.message ?? "Failed to load LLM configurations") || + (preferencesError?.message ?? "Failed to load preferences") || + (globalConfigsError?.message ?? "Failed to load global configurations")} + </AlertDescription> + </Alert> + </div> + )} + + {/* Loading Skeleton */} + {isLoading && ( + <div className="grid gap-4 grid-cols-1 lg:grid-cols-2"> + {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + <Card key={key} className="border-accent bg-accent/20"> + <CardContent className="p-4 flex flex-col gap-3 min-h-32"> + <Skeleton className="h-4 w-32 md:w-40 bg-accent" /> + <Skeleton className="h-3 w-full bg-accent" /> + <Skeleton className="h-3 w-24 md:w-28 bg-accent mt-auto" /> + </CardContent> + </Card> + ))} + </div> + )} + + {/* No configs warning */} + {!isLoading && !hasError && !hasAnyConfigs && ( + <Alert variant="destructive" className="py-3 md:py-4"> + <AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> + <AlertDescription className="text-xs md:text-sm"> + No configurations found. Please add at least one LLM provider or image model in the + respective settings tabs before assigning roles. + </AlertDescription> + </Alert> + )} + + {/* Role Assignment Cards */} + {!isLoading && !hasError && hasAnyConfigs && ( + <div className="grid gap-4 grid-cols-1 lg:grid-cols-2"> + {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[role.prefKey as keyof typeof assignments]; + + // Pick the right config lists based on role type + const roleGlobalConfigs = + role.configType === "image" + ? globalImageConfigs + : role.configType === "vision" + ? globalVisionConfigs + : globalConfigs; + const roleUserConfigs = + role.configType === "image" + ? (userImageConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") + : role.configType === "vision" + ? (userVisionConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") + : newLLMConfigs.filter((c) => c.id && c.id.toString().trim() !== ""); + const roleAllConfigs = + role.configType === "image" + ? allImageConfigs + : role.configType === "vision" + ? allVisionConfigs + : allLLMConfigs; + + const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment); + const isAssigned = !!assignedConfig; + + return ( + <div key={key}> + <Card className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"> + <CardContent className="p-4 md:p-5 space-y-4"> + {/* Role Header */} + <div className="flex items-start justify-between gap-3"> + <div className="flex items-center gap-3 min-w-0"> + <div + className={cn( + "flex items-center justify-center w-9 h-9 rounded-lg shrink-0", + role.bgColor + )} + > + <IconComponent className={cn("w-4 h-4", role.color)} /> + </div> + <div className="min-w-0"> + <h4 className="text-sm font-semibold tracking-tight">{role.title}</h4> + <p className="text-[11px] text-muted-foreground/70 mt-0.5"> + {role.description} + </p> + </div> + </div> + {savingRole === role.prefKey ? ( + <Spinner size="sm" className="shrink-0 mt-0.5 text-muted-foreground" /> + ) : isAssigned ? ( + <CircleCheck className="w-4 h-4 text-muted-foreground/40 shrink-0 mt-0.5" /> + ) : ( + <CircleDashed className="w-4 h-4 text-muted-foreground/40 shrink-0 mt-0.5" /> + )} + </div> + + {/* Selector */} + <div className="space-y-1.5"> + <Label className="text-xs font-medium text-muted-foreground"> + Configuration + </Label> + <Select + value={assignedConfig ? assignedConfig.id.toString() : "unassigned"} + onValueChange={(value) => handleRoleAssignment(role.prefKey, value)} + > + <SelectTrigger className="w-full h-9 md:h-10 text-xs md:text-sm"> + <SelectValue placeholder="Select a configuration" /> + </SelectTrigger> + <SelectContent className="max-w-[calc(100vw-2rem)] select-none"> + <SelectItem + value="unassigned" + className="text-xs md:text-sm py-1.5 md:py-2" + > + <span className="text-muted-foreground">Unassigned</span> + </SelectItem> + + {/* Global Configurations */} + {roleGlobalConfigs.length > 0 && ( + <SelectGroup> + <SelectLabel className="text-[11px] md:text-xs font-semibold text-muted-foreground px-2 py-1 md:py-1.5"> + Global Configurations + </SelectLabel> + {roleGlobalConfigs.map((config) => { + const isAuto = "is_auto_mode" in config && config.is_auto_mode; + // Read billing_tier from the global config; default to "free" + // for legacy YAMLs / Auto stub. Premium gets a purple badge, + // free gets an emerald one — same palette as the chat + // model selector so the meaning is consistent across + // surfaces (issues E, H). + const billingTier = + ("billing_tier" in config && + typeof config.billing_tier === "string" && + config.billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + <SelectItem + key={config.id} + value={config.id.toString()} + className="text-xs md:text-sm py-1.5 md:py-2" + textValue={config.name} + > + <div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0"> + <span className="truncate text-xs md:text-sm"> + {config.name} + </span> + {isAuto ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-zinc-200 text-zinc-600 dark:bg-zinc-700 dark:text-zinc-300 [[data-slot=select-trigger]_&]:hidden" + > + Recommended + </Badge> + ) : isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0 [[data-slot=select-trigger]_&]:hidden" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0 [[data-slot=select-trigger]_&]:hidden" + > + Free + </Badge> + )} + </div> + </SelectItem> + ); + })} + </SelectGroup> + )} + + {/* Custom Configurations */} + {roleUserConfigs.length > 0 && ( + <SelectGroup> + <SelectLabel className="text-[11px] md:text-xs font-semibold text-muted-foreground px-2 py-1 md:py-1.5"> + Your Configurations + </SelectLabel> + {roleUserConfigs.map((config) => ( + <SelectItem + key={config.id} + value={config.id.toString()} + className="text-xs md:text-sm py-1.5 md:py-2" + > + <div className="flex items-center gap-1 md:gap-1.5 flex-wrap min-w-0"> + <span className="truncate text-xs md:text-sm"> + {config.name} + </span> + </div> + </SelectItem> + ))} + </SelectGroup> + )} + </SelectContent> + </Select> + </div> + </CardContent> + </Card> + </div> + ); + })} + </div> + )} + </div> + ); +} diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx deleted file mode 100644 index 3b30b1558..000000000 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ /dev/null @@ -1,153 +0,0 @@ -"use client"; - -import { useAtom, useAtomValue } from "jotai"; -import { Dot } from "lucide-react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; -import { ModelProviderConnectionsPanel } from "./model-connections/model-provider-connections-panel"; -import { capability, modelLabel } from "./model-connections/model-utils"; -import { providerDisplay, providerIcon } from "./model-connections/provider-metadata"; - -function flattenModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models.map((model) => ({ - ...model, - connectionName: providerDisplay(connection.provider).name, - connectionId: connection.id, - provider: connection.provider, - })) - ); -} - -function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) { - if (!modelId) return "0"; - return models.some((model) => model.id === modelId) ? String(modelId) : "0"; -} - -function renderAutoModeOption() { - return ( - <SelectItem value="0"> - <span className="inline-flex items-center gap-2"> - {getProviderIcon(AUTO_PROVIDER_ICON_KEY)} - <span>Auto mode</span> - </span> - </SelectItem> - ); -} - -export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { - const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); - const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = useAtomValue(updateModelRolesMutationAtom); - - const allConnections = [...globalConnections, ...connections]; - const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); - const chatModels = enabledModels.filter((model) => capability(model, "chat")); - const visionModels = enabledModels.filter((model) => capability(model, "vision")); - const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); - - function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { - return ( - <SelectItem key={model.id} value={String(model.id)}> - <span className="inline-flex items-center gap-2"> - {providerIcon(model.provider)} - <span className="inline-flex items-center gap-1"> - <span>{modelLabel(model)}</span> - <Dot className="size-4 text-muted-foreground" aria-hidden="true" /> - <span>{model.connectionName}</span> - </span> - </span> - </SelectItem> - ); - } - - return ( - <div className="flex flex-col gap-6"> - <div className="flex flex-col gap-4"> - <div> - <h3 className="text-base font-semibold">Model Roles</h3> - <p className="text-sm text-muted-foreground"> - Pick which enabled model powers chat, vision, and image generation for this search - space. - </p> - </div> - <div className="flex w-full max-w-2xl flex-col gap-4"> - <div className="flex flex-col gap-2"> - <Label>Chat model</Label> - <p className="text-xs text-muted-foreground"> - Primary model for chat responses and agent tasks. You can also change it from the - chat. - </p> - <Select - value={roleSelectValue(roles?.chat_model_id, chatModels)} - onValueChange={(value) => updateRoles.mutate({ chat_model_id: Number(value) })} - > - <SelectTrigger className="w-full"> - <SelectValue /> - </SelectTrigger> - <SelectContent> - {renderAutoModeOption()} - {chatModels.map(renderModelOption)} - </SelectContent> - </Select> - </div> - <div className="flex flex-col gap-2"> - <Label>Vision model</Label> - <p className="text-xs text-muted-foreground"> - Used to understand images in uploads, documents, connectors, and automations. Falls - back to chat model when possible. - </p> - <Select - value={roleSelectValue(roles?.vision_model_id, visionModels)} - onValueChange={(value) => updateRoles.mutate({ vision_model_id: Number(value) })} - > - <SelectTrigger className="w-full"> - <SelectValue /> - </SelectTrigger> - <SelectContent> - {renderAutoModeOption()} - {visionModels.map(renderModelOption)} - </SelectContent> - </Select> - </div> - <div className="flex flex-col gap-2"> - <Label>Image generation model</Label> - <p className="text-xs text-muted-foreground">Used when generating images in chat.</p> - <Select - value={roleSelectValue(roles?.image_gen_model_id, imageModels)} - onValueChange={(value) => updateRoles.mutate({ image_gen_model_id: Number(value) })} - > - <SelectTrigger className="w-full"> - <SelectValue /> - </SelectTrigger> - <SelectContent> - {renderAutoModeOption()} - {imageModels.map(renderModelOption)} - </SelectContent> - </Select> - </div> - </div> - </div> - - <Separator /> - - <ModelProviderConnectionsPanel searchSpaceId={searchSpaceId} connections={connections} /> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx deleted file mode 100644 index 451f053db..000000000 --- a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { ApiKeyField } from "./connect-fields"; -import { - isValidAzureTargetUri, - type ProviderConnectFormProps, - parseAzureTargetUri, -} from "./provider-metadata"; - -/** - * Azure OpenAI connect form. The user pastes a single Target URI, which we parse - * into api base, api version, and the deployment name (seeded as the model). - */ -export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [targetUri, setTargetUri] = useState(""); - const [apiKey, setApiKey] = useState(""); - const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); - - useEffect(() => { - const parsed = parseAzureTargetUri(targetUri); - onDraftChange( - { - base_url: parsed?.origin ?? null, - api_key: apiKey || null, - extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, - seedModelId: parsed?.deploymentName || undefined, - }, - canSubmit - ); - }, [apiKey, canSubmit, onDraftChange, targetUri]); - - return ( - <div className="flex flex-col gap-4"> - <div className="flex flex-col gap-2"> - <Label>Target URI</Label> - <Input - value={targetUri} - onChange={(event) => setTargetUri(event.target.value)} - placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" - /> - <p className="text-xs text-muted-foreground"> - Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and - API version). - </p> - </div> - <ApiKeyField - value={apiKey} - onChange={setApiKey} - placeholder="Paste your API key from Azure" - /> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx deleted file mode 100644 index f76308421..000000000 --- a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx +++ /dev/null @@ -1,120 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { ApiKeyField } from "./connect-fields"; -import { - AWS_REGION_OPTIONS, - BEDROCK_AUTH_ACCESS_KEY, - BEDROCK_AUTH_IAM, - BEDROCK_AUTH_LONG_TERM_API_KEY, - type ProviderConnectFormProps, -} from "./provider-metadata"; - -/** - * Amazon Bedrock connect form. Region + auth method drive which AWS credentials - * are collected; everything rides along in `extra.litellm_params`. - */ -export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [region, setRegion] = useState(""); - const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); - const [accessKeyId, setAccessKeyId] = useState(""); - const [secretAccessKey, setSecretAccessKey] = useState(""); - const [bearerToken, setBearerToken] = useState(""); - - const canSubmit = (() => { - if (!region) return false; - if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { - return Boolean(accessKeyId && secretAccessKey); - } - if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { - return Boolean(bearerToken); - } - return true; - })(); - - useEffect(() => { - const params: Record<string, string> = { aws_region_name: region }; - if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { - params.aws_access_key_id = accessKeyId; - params.aws_secret_access_key = secretAccessKey; - } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { - params.aws_bearer_token_bedrock = bearerToken; - } - onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); - }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]); - - return ( - <div className="flex flex-col gap-4"> - <div className="flex flex-col gap-2"> - <Label>AWS Region</Label> - <Select value={region || undefined} onValueChange={setRegion}> - <SelectTrigger> - <SelectValue placeholder="Select a region" /> - </SelectTrigger> - <SelectContent> - {AWS_REGION_OPTIONS.map((option) => ( - <SelectItem key={option} value={option}> - {option} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - <div className="flex flex-col gap-2"> - <Label>Authentication Method</Label> - <Select value={authMethod} onValueChange={setAuthMethod}> - <SelectTrigger> - <SelectValue /> - </SelectTrigger> - <SelectContent> - <SelectItem value={BEDROCK_AUTH_IAM}>Environment IAM Role</SelectItem> - <SelectItem value={BEDROCK_AUTH_ACCESS_KEY}>Access Key</SelectItem> - <SelectItem value={BEDROCK_AUTH_LONG_TERM_API_KEY}>Long-term API Key</SelectItem> - </SelectContent> - </Select> - </div> - {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( - <> - <div className="flex flex-col gap-2"> - <Label>AWS Access Key ID</Label> - <Input - value={accessKeyId} - onChange={(event) => setAccessKeyId(event.target.value)} - placeholder="Enter your AWS access key ID" - /> - </div> - <ApiKeyField - value={secretAccessKey} - onChange={setSecretAccessKey} - label="AWS Secret Access Key" - placeholder="Enter your AWS secret access key" - /> - </> - ) : null} - {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( - <ApiKeyField - value={bearerToken} - onChange={setBearerToken} - label="Long-term API Key" - placeholder="Your long-term API key" - /> - ) : null} - {authMethod === BEDROCK_AUTH_IAM ? ( - <p className="text-xs text-muted-foreground"> - SurfSense will use the IAM role attached to the environment it's running in to - authenticate. - </p> - ) : null} - <p className="text-xs text-muted-foreground"> - Add Bedrock model IDs from the provider's settings after connecting. - </p> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx deleted file mode 100644 index 584fb98b0..000000000 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ /dev/null @@ -1,105 +0,0 @@ -import { Eye, EyeOff } from "lucide-react"; -import type { ReactNode } from "react"; -import { useState } from "react"; -import { Button } from "@/components/ui/button"; -import { DialogFooter } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Spinner } from "@/components/ui/spinner"; - -interface ApiBaseUrlFieldProps { - value: string; - onChange: (value: string) => void; - /** Placeholder, typically the provider's prefilled default base URL. */ - placeholder?: string; - hint?: ReactNode; -} - -/** Shared API Base URL input. The prefilled default is passed in via `value`. */ -export function ApiBaseUrlField({ value, onChange, placeholder, hint }: ApiBaseUrlFieldProps) { - return ( - <div className="flex flex-col gap-2"> - <Label>API Base URL</Label> - <Input - value={value} - onChange={(event) => onChange(event.target.value)} - placeholder={placeholder || "https://api.example.com/v1"} - /> - {hint ? <p className="text-xs text-muted-foreground">{hint}</p> : null} - </div> - ); -} - -interface ApiKeyFieldProps { - value: string; - onChange: (value: string) => void; - label?: string; - placeholder?: string; -} - -/** Shared masked API Key input. */ -export function ApiKeyField({ - value, - onChange, - label = "API Key", - placeholder = "API key", -}: ApiKeyFieldProps) { - const [showApiKey, setShowApiKey] = useState(false); - - return ( - <div className="flex flex-col gap-2"> - <Label>{label}</Label> - <div className="relative"> - <Input - value={value} - onChange={(event) => onChange(event.target.value)} - placeholder={placeholder} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - <Button - type="button" - variant="ghost" - size="icon" - className="absolute top-1/2 right-1 size-8 -translate-y-1/2 text-muted-foreground" - onClick={() => setShowApiKey((current) => !current)} - disabled={!value} - aria-label={showApiKey ? "Hide API key" : "Show API key"} - > - {showApiKey ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />} - </Button> - </div> - </div> - ); -} - -interface ConnectFormFooterProps { - onCancel: () => void; - onSubmit: () => void; - canSubmit: boolean; - isPending: boolean; -} - -/** Shared Cancel / Connect footer for every provider connect form. */ -export function ConnectFormFooter({ - onCancel, - onSubmit, - canSubmit, - isPending, -}: ConnectFormFooterProps) { - return ( - <DialogFooter className="shrink-0 border-t bg-popover px-6 py-4"> - <Button variant="secondary" onClick={onCancel}> - Cancel - </Button> - <Button - onClick={onSubmit} - disabled={isPending || !canSubmit} - className="relative min-w-[96px]" - > - <span className={isPending ? "opacity-0" : ""}>Connect</span> - {isPending ? <Spinner size="sm" className="absolute" /> : null} - </Button> - </DialogFooter> - ); -} diff --git a/surfsense_web/components/settings/model-connections/connection-card.tsx b/surfsense_web/components/settings/model-connections/connection-card.tsx deleted file mode 100644 index b482cac9f..000000000 --- a/surfsense_web/components/settings/model-connections/connection-card.tsx +++ /dev/null @@ -1,88 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { Trash2 } from "lucide-react"; -import { deleteModelConnectionMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import type { ConnectionRead } from "@/contracts/types/model-connections.types"; -import { ConnectionSettingsDialog } from "./connection-settings-dialog"; -import { providerDisplay, providerIcon } from "./provider-metadata"; - -export function ConnectionCard({ connection }: { connection: ConnectionRead }) { - const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); - - const providerMeta = providerDisplay(connection.provider); - const providerLabel = providerMeta.name; - - function deleteCurrentConnection() { - deleteConnection.mutate(connection.id); - } - - return ( - <div className="overflow-hidden rounded-lg border border-border/60"> - <div className="flex items-center justify-between gap-3 p-4 transition-colors hover:bg-accent"> - <div className="min-w-0"> - <div className="flex items-center gap-2 font-semibold"> - {providerIcon(connection.provider)} - <span className="truncate">{providerLabel}</span> - {connection.scope === "GLOBAL" ? ( - <Badge variant="outline" className="text-[10px]"> - Default - </Badge> - ) : null} - </div> - <div className="truncate text-sm text-muted-foreground"> - {connection.base_url || "Provider default endpoint"} - </div> - </div> - <div className="flex shrink-0 items-center gap-2"> - <ConnectionSettingsDialog connection={connection} providerLabel={providerLabel} /> - <AlertDialog> - <AlertDialogTrigger asChild> - <Button - variant="ghost" - size="icon" - className="text-muted-foreground hover:text-accent-foreground" - disabled={deleteConnection.isPending} - aria-label={`Delete ${providerLabel}`} - > - <Trash2 className="h-4 w-4" /> - </Button> - </AlertDialogTrigger> - <AlertDialogContent> - <AlertDialogHeader> - <AlertDialogTitle>Delete this provider?</AlertDialogTitle> - <AlertDialogDescription> - <span className="font-medium text-foreground">{providerLabel}</span> and all of - its models will be removed from this search space. This cannot be undone. - </AlertDialogDescription> - </AlertDialogHeader> - <AlertDialogFooter> - <AlertDialogCancel disabled={deleteConnection.isPending}>Cancel</AlertDialogCancel> - <AlertDialogAction - onClick={deleteCurrentConnection} - disabled={deleteConnection.isPending} - className="bg-destructive text-white hover:bg-destructive/90" - > - Delete - </AlertDialogAction> - </AlertDialogFooter> - </AlertDialogContent> - </AlertDialog> - </div> - </div> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx deleted file mode 100644 index 1f16c3bd0..000000000 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ /dev/null @@ -1,333 +0,0 @@ -import { useAtomValue } from "jotai"; -import { Eye, EyeOff, Settings } from "lucide-react"; -import { useMemo, useState } from "react"; -import { - addManualModelMutationAtom, - bulkUpdateModelsMutationAtom, - discoverConnectionModelsMutationAtom, - testPreviewModelMutationAtom, - updateModelConnectionMutationAtom, -} from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { Button } from "@/components/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import type { - ConnectionRead, - ConnectionUpdateRequest, -} from "@/contracts/types/model-connections.types"; -import { capability, type SelectableModel } from "./model-utils"; -import { ModelsSelectionPanel } from "./models-selection-panel"; -import { providerIcon } from "./provider-metadata"; - -interface ConnectionSettingsDialogProps { - connection: ConnectionRead; - providerLabel: string; -} - -function enabledModelIds(models: SelectableModel[]) { - return new Set( - models - .filter((model) => typeof model.id === "number" && model.enabled) - .map((model) => Number(model.id)) - ); -} - -export function ConnectionSettingsDialog({ - connection, - providerLabel, -}: ConnectionSettingsDialogProps) { - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); - const updateConnection = useAtomValue(updateModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); - - const allowlist = Array.isArray(connection.extra?.model_ids) - ? (connection.extra.model_ids as string[]) - : []; - const [isOpen, setIsOpen] = useState(false); - const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); - const [apiKeyDraft, setApiKeyDraft] = useState(""); - const [showApiKey, setShowApiKey] = useState(false); - const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); - const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false); - const [draftEnabledModelIds, setDraftEnabledModelIds] = useState(() => - enabledModelIds(connection.models) - ); - - const isLocal = - connection.provider === "ollama_chat" || - connection.provider === "lm_studio" || - !connection.base_url?.startsWith("https"); - const hasConnectionChanges = - baseUrlDraft.trim() !== (connection.base_url ?? "") || - apiKeyDraft.trim() !== (connection.api_key ?? ""); - const draftModels = useMemo( - () => - connection.models.map((model) => - typeof model.id === "number" - ? { ...model, enabled: draftEnabledModelIds.has(model.id) } - : model - ), - [connection.models, draftEnabledModelIds] - ); - const hasModelChanges = connection.models.some( - (model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id) !== model.enabled - ); - const canUpdate = hasConnectionChanges || hasModelChanges; - - function handleOpenChange(open: boolean) { - setIsOpen(open); - if (open) { - setBaseUrlDraft(connection.base_url ?? ""); - setApiKeyDraft(connection.api_key ?? ""); - setShowApiKey(false); - setAllowlistText(allowlist.join(", ")); - setIsSavingConnectionSettings(false); - setDraftEnabledModelIds(enabledModelIds(connection.models)); - } - } - - async function saveModelChanges() { - const toEnable = connection.models - .filter((model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id)) - .filter((model) => !model.enabled) - .map((model) => Number(model.id)); - const toDisable = connection.models - .filter((model) => typeof model.id === "number" && !draftEnabledModelIds.has(model.id)) - .filter((model) => model.enabled) - .map((model) => Number(model.id)); - - if (toEnable.length > 0) { - await bulkUpdateModels.mutateAsync({ - connectionId: connection.id, - data: { model_ids: toEnable, enabled: true }, - }); - } - if (toDisable.length > 0) { - await bulkUpdateModels.mutateAsync({ - connectionId: connection.id, - data: { model_ids: toDisable, enabled: false }, - }); - } - } - - async function saveConnectionSettings() { - if (isSavingConnectionSettings) return; - - const data: ConnectionUpdateRequest = { - base_url: baseUrlDraft.trim() || null, - }; - - if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { - data.api_key = apiKeyDraft.trim() || null; - } - const apiKeyForTest = Object.hasOwn(data, "api_key") - ? (data.api_key ?? null) - : (connection.api_key ?? null); - - const enabledModels = draftModels.filter((model) => model.enabled); - const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; - setIsSavingConnectionSettings(true); - try { - if (hasConnectionChanges) { - if (testModel) { - const result = await testPreviewModel.mutateAsync({ - provider: connection.provider, - base_url: data.base_url, - api_key: apiKeyForTest, - scope: "SEARCH_SPACE", - search_space_id: connection.search_space_id, - extra: connection.extra ?? {}, - enabled: connection.enabled, - models: [], - model_id: testModel.model_id, - }); - if (!result.ok) return; - } - await updateConnection.mutateAsync({ id: connection.id, data }); - setApiKeyDraft(""); - } - - if (hasModelChanges) { - await saveModelChanges(); - } - } finally { - setIsSavingConnectionSettings(false); - } - } - - function saveAllowlist() { - const ids = allowlistText - .split(",") - .map((value) => value.trim()) - .filter(Boolean); - updateConnection.mutate({ - id: connection.id, - data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, - }); - } - - function handleToggleModel(model: SelectableModel, enabled: boolean) { - if (typeof model.id !== "number") return; - const modelId = model.id; - setDraftEnabledModelIds((current) => { - const next = new Set(current); - if (enabled) { - next.add(modelId); - } else { - next.delete(modelId); - } - return next; - }); - } - - function handleBulkToggle(models: SelectableModel[], enabled: boolean) { - const modelIds = models - .map((model) => model.id) - .filter((id): id is number => typeof id === "number"); - if (modelIds.length === 0) return; - setDraftEnabledModelIds((current) => { - const next = new Set(current); - for (const id of modelIds) { - if (enabled) { - next.add(id); - } else { - next.delete(id); - } - } - return next; - }); - } - - return ( - <Dialog open={isOpen} onOpenChange={handleOpenChange}> - <DialogTrigger asChild> - <Button - variant="ghost" - size="icon" - className="text-muted-foreground hover:text-accent-foreground" - aria-label={`Configure ${providerLabel}`} - > - <Settings className="h-4 w-4" /> - </Button> - </DialogTrigger> - <DialogContent className="flex max-h-[90vh] max-w-3xl flex-col overflow-hidden bg-popover p-0 text-popover-foreground"> - <DialogHeader className="shrink-0 border-b px-6 py-5"> - <div className="flex items-center gap-3"> - {providerIcon(connection.provider, "size-5")} - <div> - <DialogTitle> - Configure <span className="italic">{providerLabel}</span> - </DialogTitle> - <DialogDescription> - Manage credentials and choose which models are available from this provider. - </DialogDescription> - </div> - </div> - </DialogHeader> - - <div className="min-h-0 flex-1 overflow-y-auto px-6 py-5"> - <div className="space-y-6"> - <div className="space-y-2"> - <Label>API Base URL</Label> - <Input - value={baseUrlDraft} - onChange={(event) => setBaseUrlDraft(event.target.value)} - placeholder="https://api.example.com/v1" - /> - <p className="text-xs text-muted-foreground"> - Leave empty to use the provider default endpoint. - </p> - </div> - - <div className="space-y-2"> - <Label>API Key</Label> - <div className="relative"> - <Input - value={apiKeyDraft} - onChange={(event) => setApiKeyDraft(event.target.value)} - placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - <Button - type="button" - variant="ghost" - size="icon" - className="absolute top-1/2 right-1 size-8 -translate-y-1/2 text-muted-foreground" - onClick={() => setShowApiKey((current) => !current)} - disabled={!apiKeyDraft} - aria-label={showApiKey ? "Hide API key" : "Show API key"} - > - {showApiKey ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />} - </Button> - </div> - </div> - - {!isLocal ? ( - <div className="space-y-2"> - <Label className="text-xs">Model IDs filter (optional)</Label> - <div className="flex gap-2"> - <Input - value={allowlistText} - onChange={(event) => setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - <Button size="sm" onClick={saveAllowlist} disabled={updateConnection.isPending}> - Save filter - </Button> - </div> - <p className="text-xs text-muted-foreground"> - Leave empty to discover all models. Recommended for providers with large catalogs. - </p> - </div> - ) : null} - - <Separator className="bg-muted-foreground/20" /> - - <ModelsSelectionPanel - models={draftModels} - isRefreshing={discoverModels.isPending} - isAddingManual={addManualModel.isPending} - isUpdatingModel={isSavingConnectionSettings} - isBulkUpdating={isSavingConnectionSettings || bulkUpdateModels.isPending} - refreshLabel={`Refresh ${providerLabel} models`} - onRefresh={() => discoverModels.mutate(connection.id)} - onAddManual={(modelId) => - addManualModel.mutate({ - connectionId: connection.id, - data: { model_id: modelId }, - }) - } - onToggleModel={handleToggleModel} - onBulkToggle={handleBulkToggle} - /> - </div> - </div> - - <DialogFooter className="shrink-0 border-t bg-popover px-6 py-4"> - <Button - onClick={saveConnectionSettings} - disabled={isSavingConnectionSettings || !canUpdate} - className="relative min-w-[96px]" - > - <span className={isSavingConnectionSettings ? "opacity-0" : ""}>Update</span> - {isSavingConnectionSettings ? <Spinner size="sm" className="absolute" /> : null} - </Button> - </DialogFooter> - </DialogContent> - </Dialog> - ); -} diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx deleted file mode 100644 index e3111202d..000000000 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { useEffect, useState } from "react"; -import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; -import type { ProviderConnectFormProps } from "./provider-metadata"; - -const OPTIONAL_API_KEY_PROVIDERS = new Set(["ollama_chat", "lm_studio", "openai_compatible"]); - -function baseUrlHint(provider: string) { - if (provider === "ollama_chat" || provider === "lm_studio") { - return "For local servers, use host.docker.internal instead of localhost."; - } - if (provider === "openai_compatible") { - return "Enter the full endpoint URL."; - } - if (provider === "openai" || provider === "anthropic" || provider === "openrouter") { - return "Override only if you route through a proxy or gateway."; - } - return undefined; -} - -/** - * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, - * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is - * prefilled from the provider default. - */ -export function DefaultConnectForm({ - provider, - defaultBaseUrl, - baseUrlRequired, - onDraftChange, -}: ProviderConnectFormProps) { - const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); - const [apiKey, setApiKey] = useState(""); - const isApiKeyOptional = OPTIONAL_API_KEY_PROVIDERS.has(provider); - const hint = baseUrlHint(provider); - const apiKeyValue = apiKey.trim(); - const canSubmit = - !(baseUrlRequired && !baseUrl.trim()) && (isApiKeyOptional || Boolean(apiKeyValue)); - - useEffect(() => { - onDraftChange( - { base_url: baseUrl || null, api_key: apiKeyValue || null, extra: {} }, - canSubmit - ); - }, [apiKeyValue, baseUrl, canSubmit, onDraftChange]); - - return ( - <div className="flex flex-col gap-4"> - <ApiBaseUrlField - value={baseUrl} - onChange={setBaseUrl} - placeholder={defaultBaseUrl} - hint={hint} - /> - <ApiKeyField - value={apiKey} - onChange={setApiKey} - label={isApiKeyOptional ? "API Key (optional)" : "API Key"} - placeholder="Enter your API key" - /> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx b/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx deleted file mode 100644 index a703ab1c8..000000000 --- a/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx +++ /dev/null @@ -1,299 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { type ReactNode, useState } from "react"; -import { toast } from "sonner"; -import { - createModelConnectionMutationAtom, - previewConnectionModelsMutationAtom, - testPreviewModelMutationAtom, -} from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { modelProvidersAtom } from "@/atoms/model-connections/model-connections-query.atoms"; -import { Button } from "@/components/ui/button"; -import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelSelection } from "@/contracts/types/model-connections.types"; -import { ConnectionCard } from "./connection-card"; -import { capability, type SelectableModel } from "./model-utils"; -import { ProviderConnectDialog } from "./provider-connect-dialog"; -import { - type ConnectionDraft, - PROVIDER_ORDER, - providerDisplay, - providerIcon, -} from "./provider-metadata"; - -interface ModelProviderConnectionsPanelProps { - searchSpaceId: number; - connections: ConnectionRead[]; - className?: string; - addProviderTitle?: string; - addProviderDescription?: string; - availableProvidersTitle?: string; - footerAction?: ReactNode; - showAddProviderHeader?: boolean; -} - -function toModelSelection(model: SelectableModel): ModelSelection { - return { - model_id: model.model_id, - display_name: model.display_name, - source: model.source || "DISCOVERED", - supports_chat: model.supports_chat, - max_input_tokens: model.max_input_tokens, - supports_image_input: model.supports_image_input, - supports_tools: model.supports_tools, - supports_image_generation: model.supports_image_generation, - enabled: model.enabled, - metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}), - }; -} - -export function ModelProviderConnectionsPanel({ - searchSpaceId, - connections, - className, - addProviderTitle = "Add Provider", - addProviderDescription = "SurfSense supports popular providers and self-hosted model endpoints.", - availableProvidersTitle = "Available Providers", - footerAction, - showAddProviderHeader = true, -}: ModelProviderConnectionsPanelProps) { - const { data: providers = [] } = useAtomValue(modelProvidersAtom); - const createConnection = useAtomValue(createModelConnectionMutationAtom); - const previewModels = useAtomValue(previewConnectionModelsMutationAtom); - const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); - - const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); - const [provider, setProvider] = useState("openai_compatible"); - const [connectModels, setConnectModels] = useState<ModelSelection[]>([]); - const selectedProvider = providers.find((item) => item.provider === provider); - - const sortedProviders = [...providers].sort((left, right) => { - const leftIndex = PROVIDER_ORDER.indexOf(left.provider); - const rightIndex = PROVIDER_ORDER.indexOf(right.provider); - if (leftIndex !== -1 || rightIndex !== -1) { - return ( - (leftIndex === -1 ? Number.MAX_SAFE_INTEGER : leftIndex) - - (rightIndex === -1 ? Number.MAX_SAFE_INTEGER : rightIndex) - ); - } - return providerDisplay(left.provider).name.localeCompare(providerDisplay(right.provider).name); - }); - - function resetConnectState() { - setConnectModels([]); - } - - function handleConnectOpenChange(open: boolean) { - setIsAddProviderOpen(open); - if (!open) { - resetConnectState(); - } - } - - function mergePreviewModels(fetchedModels: SelectableModel[]) { - setConnectModels((current) => { - const currentById = new Map(current.map((model) => [model.model_id, model])); - return fetchedModels.map((model) => { - const prior = currentById.get(model.model_id); - return { - ...toModelSelection(model), - enabled: prior ? prior.enabled : model.enabled, - }; - }); - }); - } - - function connectionModelsForDraft(draft: ConnectionDraft) { - const models = [...connectModels]; - if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { - models.push({ - model_id: draft.seedModelId, - display_name: draft.seedModelId, - source: "MANUAL", - enabled: true, - metadata: {}, - }); - } - return models; - } - - function representativeTestModel(models: ModelSelection[]) { - const enabledModels = models.filter((model) => model.enabled); - return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; - } - - // Each provider connect form builds its own credential payload; the backend - // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. - function handleCreate(draft: ConnectionDraft) { - const models = connectionModelsForDraft(draft); - const testModel = representativeTestModel(models); - if (!testModel) { - toast.error("Select at least one model before connecting"); - return; - } - - const request = { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE" as const, - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models, - }; - - testPreviewModel.mutate( - { ...request, model_id: testModel.model_id }, - { - onSuccess: (result) => { - if (!result.ok) return; - createConnection.mutate(request, { - onSuccess: () => { - setIsAddProviderOpen(false); - resetConnectState(); - }, - }); - }, - } - ); - } - - function openProviderDialog(providerId: string) { - resetConnectState(); - setProvider(providerId); - setIsAddProviderOpen(true); - if (providerId === "vertex_ai") { - previewModels.mutate( - { - provider: providerId, - base_url: null, - api_key: null, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: {}, - enabled: true, - models: [], - }, - { - onSuccess: mergePreviewModels, - } - ); - } - } - - function refreshConnectModels(draft: ConnectionDraft) { - previewModels.mutate( - { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models: [], - }, - { - onSuccess: mergePreviewModels, - } - ); - } - - function addConnectModel(modelId: string) { - setConnectModels((current) => { - if (current.some((model) => model.model_id === modelId)) return current; - return [ - ...current, - { - model_id: modelId, - display_name: modelId, - source: "MANUAL", - enabled: true, - metadata: {}, - }, - ]; - }); - } - - function toggleConnectModel(model: SelectableModel, enabled: boolean) { - setConnectModels((current) => - current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item)) - ); - } - - function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) { - const modelIds = new Set(models.map((model) => model.model_id)); - setConnectModels((current) => - current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item)) - ); - } - - return ( - <div className={className ?? "flex flex-col gap-6"}> - <div className="flex flex-col gap-3"> - {showAddProviderHeader ? ( - <div> - <h3 className="text-base font-semibold">{addProviderTitle}</h3> - <p className="text-sm text-muted-foreground">{addProviderDescription}</p> - </div> - ) : null} - <div className="grid gap-3 md:grid-cols-2"> - {sortedProviders.map((item) => { - const meta = providerDisplay(item.provider); - - return ( - <Button - key={item.provider} - variant="ghost" - type="button" - className="h-auto justify-between gap-3 whitespace-normal rounded-lg border border-border/60 p-4 text-left transition-colors hover:bg-accent hover:text-accent-foreground" - onClick={() => openProviderDialog(item.provider)} - > - <span className="flex min-w-0 items-center gap-3"> - {providerIcon(item.provider, "size-5")} - <span className="min-w-0"> - <span className="block truncate text-sm font-semibold">{meta.name}</span> - <span className="block truncate text-xs text-muted-foreground"> - {meta.subtitle} - </span> - </span> - </span> - <span className="shrink-0 text-sm font-medium text-muted-foreground">Connect</span> - </Button> - ); - })} - </div> - </div> - - <ProviderConnectDialog - open={isAddProviderOpen} - onOpenChange={handleConnectOpenChange} - provider={provider} - selectedProvider={selectedProvider} - isPending={createConnection.isPending || testPreviewModel.isPending} - onSubmit={handleCreate} - previewModels={connectModels} - isPreviewingModels={previewModels.isPending} - onPreviewModels={refreshConnectModels} - onAddPreviewModel={addConnectModel} - onTogglePreviewModel={toggleConnectModel} - onBulkTogglePreviewModels={bulkToggleConnectModels} - /> - - {connections.length > 0 ? ( - <div className="flex flex-col gap-3"> - <Separator /> - <h3 className="text-base font-semibold">{availableProvidersTitle}</h3> - <div className="flex flex-col gap-3"> - {connections.map((connection) => ( - <ConnectionCard key={connection.id} connection={connection} /> - ))} - </div> - </div> - ) : null} - {footerAction ? <div className="flex justify-center pt-2">{footerAction}</div> : null} - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts deleted file mode 100644 index 2887f2179..000000000 --- a/surfsense_web/components/settings/model-connections/model-utils.ts +++ /dev/null @@ -1,30 +0,0 @@ -import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types"; - -export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; - -export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ - { key: "chat", label: "Chat" }, - { key: "vision", label: "Vision" }, - { key: "image_gen", label: "Image" }, -]; - -export type SelectableModel = (ModelRead | ModelPreviewRead) & { - id?: number | string; - connection_id?: number; -}; - -export function modelLabel(model: SelectableModel) { - return model.display_name || model.model_id; -} - -export function capability(model: SelectableModel, key: ModelCapabilityFilter) { - if (key === "chat") return Boolean(model.supports_chat); - if (key === "vision") return Boolean(model.supports_image_input); - return Boolean(model.supports_image_generation); -} - -export function capabilityLabels(model: SelectableModel) { - return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) - .map((filter) => filter.label.toLowerCase()) - .join(", "); -} diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx deleted file mode 100644 index 3c6990afb..000000000 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ /dev/null @@ -1,198 +0,0 @@ -import { RefreshCw } from "lucide-react"; -import { useState } from "react"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Checkbox } from "@/components/ui/checkbox"; -import { Input } from "@/components/ui/input"; -import { Spinner } from "@/components/ui/spinner"; -import { - capability, - capabilityLabels, - MODEL_CAPABILITY_FILTERS, - type ModelCapabilityFilter, - modelLabel, - type SelectableModel, -} from "./model-utils"; - -interface ModelsSelectionPanelProps { - models: SelectableModel[]; - description?: string; - emptyMessage?: string; - manualInputPlaceholder?: string; - refreshLabel?: string; - isRefreshing?: boolean; - isAddingManual?: boolean; - isUpdatingModel?: boolean; - isBulkUpdating?: boolean; - onRefresh?: () => void; - onAddManual?: (modelId: string) => void; - onToggleModel?: (model: SelectableModel, enabled: boolean) => void; - onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void; -} - -export function ModelsSelectionPanel({ - models, - description = "Select models to make available for this provider.", - emptyMessage = "No models available.", - manualInputPlaceholder = "Add a model ID manually", - refreshLabel = "Refresh models", - isRefreshing = false, - isAddingManual = false, - isUpdatingModel = false, - isBulkUpdating = false, - onRefresh, - onAddManual, - onToggleModel, - onBulkToggle, -}: ModelsSelectionPanelProps) { - const [manualModelId, setManualModelId] = useState(""); - const [modelFilter, setModelFilter] = useState<ModelCapabilityFilter | null>(null); - - const filteredModels = modelFilter - ? models.filter((model) => capability(model, modelFilter)) - : models; - const allFilteredModelsEnabled = - filteredModels.length > 0 && filteredModels.every((model) => model.enabled); - - function addModel() { - const modelId = manualModelId.trim(); - if (!modelId || !onAddManual) return; - onAddManual(modelId); - setManualModelId(""); - } - - function toggleFilteredModels() { - const nextEnabled = !allFilteredModelsEnabled; - const changedModels = filteredModels.filter((model) => model.enabled !== nextEnabled); - if (changedModels.length === 0) return; - onBulkToggle?.(changedModels, nextEnabled); - } - - return ( - <div className="space-y-3"> - <div className="flex flex-wrap items-start justify-between gap-3"> - <div> - <div className="font-semibold">Models</div> - <p className="text-sm text-muted-foreground">{description}</p> - </div> - <div className="flex flex-wrap items-center gap-2"> - <Button - variant="ghost" - size="sm" - type="button" - onClick={toggleFilteredModels} - disabled={!onBulkToggle || isBulkUpdating || filteredModels.length === 0} - > - {allFilteredModelsEnabled ? "Deselect All" : "Select All"} - </Button> - {onRefresh ? ( - <Button - variant="ghost" - size="icon" - type="button" - onClick={onRefresh} - disabled={isRefreshing} - aria-label={refreshLabel} - > - <RefreshCw className={`h-4 w-4 ${isRefreshing ? "animate-spin" : ""}`} /> - </Button> - ) : null} - </div> - </div> - - {onAddManual ? ( - <div className="flex gap-2"> - <Input - value={manualModelId} - onChange={(event) => setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder={manualInputPlaceholder} - /> - <Button - size="sm" - type="button" - onClick={addModel} - disabled={isAddingManual || !manualModelId.trim()} - className="relative min-w-[88px]" - > - <span className={isAddingManual ? "opacity-0" : ""}>Add model</span> - {isAddingManual ? <Spinner size="xs" className="absolute" /> : null} - </Button> - </div> - ) : null} - - {models.length > 0 ? ( - <div className="flex flex-wrap items-center gap-2"> - <span className="text-xs font-medium text-muted-foreground">Filter models</span> - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = models.filter((model) => capability(model, filter.key)).length; - const isActive = modelFilter === filter.key; - - return ( - <Button - key={filter.key} - type="button" - variant="secondary" - size="sm" - className={`h-7 rounded-full px-3 text-xs ${isActive ? "" : "opacity-80"}`} - onClick={() => setModelFilter(isActive ? null : filter.key)} - > - {filter.label} - <span className="ml-1 text-muted-foreground">{count}</span> - </Button> - ); - })} - </div> - ) : null} - - <div className="h-80 overflow-y-auto rounded-xl border bg-muted/20 p-2"> - {models.length === 0 ? ( - <div className="rounded-lg px-3 py-6 text-center text-sm text-muted-foreground"> - {emptyMessage} - </div> - ) : null} - {filteredModels.length === 0 && modelFilter ? ( - <div className="rounded-lg px-3 py-6 text-center text-sm text-muted-foreground"> - No{" "} - {MODEL_CAPABILITY_FILTERS.find( - (filter) => filter.key === modelFilter - )?.label.toLowerCase()}{" "} - models found on this connection. - </div> - ) : null} - <div className="space-y-2"> - {filteredModels.map((model) => ( - <div - key={model.id ?? model.model_id} - className="flex items-center gap-3 rounded-lg px-3 py-2 transition-colors hover:bg-background" - > - <Checkbox - checked={model.enabled} - onCheckedChange={(checked) => onToggleModel?.(model, checked === true)} - disabled={!onToggleModel || isUpdatingModel} - /> - <div className="min-w-0 flex-1"> - <div className="flex items-center gap-2 text-sm font-medium"> - <span className="truncate">{modelLabel(model)}</span> - {model.source === "MANUAL" ? ( - <Badge variant="outline" className="text-[10px]"> - manual - </Badge> - ) : null} - </div> - <div className="text-xs text-muted-foreground"> - {capabilityLabels(model) || "No discovered capabilities"} - </div> - </div> - </div> - ))} - </div> - </div> - </div> - ); -} diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx deleted file mode 100644 index 2eee2cf8c..000000000 --- a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx +++ /dev/null @@ -1,155 +0,0 @@ -import { useCallback, useRef, useState } from "react"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; -import { Separator } from "@/components/ui/separator"; -import type { ModelProviderRead } from "@/contracts/types/model-connections.types"; -import { AzureConnectForm } from "./azure-connect-form"; -import { BedrockConnectForm } from "./bedrock-connect-form"; -import { ConnectFormFooter } from "./connect-fields"; -import { DefaultConnectForm } from "./default-connect-form"; -import type { SelectableModel } from "./model-utils"; -import { ModelsSelectionPanel } from "./models-selection-panel"; -import { - type ConnectionDraft, - type ProviderConnectFormProps, - providerDefaultBaseUrl, - providerDisplay, - providerIcon, -} from "./provider-metadata"; -import { VertexConnectForm } from "./vertex-connect-form"; - -interface ProviderConnectDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - provider: string; - selectedProvider?: ModelProviderRead; - isPending: boolean; - onSubmit: (draft: ConnectionDraft) => void; - previewModels?: SelectableModel[]; - isPreviewingModels?: boolean; - onPreviewModels?: (draft: ConnectionDraft) => void; - onAddPreviewModel?: (modelId: string) => void; - onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void; - onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void; -} - -/** - * Shared dialog shell for the "Add Provider" flow. It owns the header and routes - * to the provider-specific connect form. Forms remount on open (Radix unmounts - * closed content), so each gets fresh, prefilled state. - */ -export function ProviderConnectDialog({ - open, - onOpenChange, - provider, - selectedProvider, - isPending, - onSubmit, - previewModels = [], - isPreviewingModels = false, - onPreviewModels, - onAddPreviewModel, - onTogglePreviewModel, - onBulkTogglePreviewModels, -}: ProviderConnectDialogProps) { - const meta = providerDisplay(provider); - const isAzure = provider === "azure"; - const isBedrock = provider === "bedrock"; - const isVertex = provider === "vertex_ai"; - const titleRef = useRef<HTMLHeadingElement>(null); - const [currentDraft, setCurrentDraft] = useState<ConnectionDraft>({ - base_url: null, - api_key: null, - extra: {}, - }); - const [canSubmit, setCanSubmit] = useState(false); - - const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => { - setCurrentDraft(draft); - setCanSubmit(nextCanSubmit); - }, []); - - const formProps: ProviderConnectFormProps = { - provider, - defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), - baseUrlRequired: Boolean(selectedProvider?.base_url_required), - onDraftChange: handleDraftChange, - }; - - const modelDescription = (() => { - if (isAzure) { - return "Select the models to enable for Azure OpenAI"; - } - if (isBedrock) { - return "Select the models to enable for Amazon Bedrock"; - } - if (isVertex) { - return "Select the models to enable for Gemini"; - } - return "Select the models to enable for this provider"; - })(); - - const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); - const hasEnabledModel = - previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId); - const canConnect = canSubmit && hasEnabledModel; - - return ( - <Dialog open={open} onOpenChange={onOpenChange}> - <DialogContent - className="flex h-[85vh] max-h-[760px] min-h-[640px] max-w-2xl flex-col overflow-hidden bg-popover p-0 text-popover-foreground" - onOpenAutoFocus={(event) => { - event.preventDefault(); - titleRef.current?.focus(); - }} - > - <DialogHeader className="shrink-0 border-b px-6 py-5"> - <div className="flex items-center gap-3"> - {providerIcon(provider, "size-5")} - <div> - <DialogTitle ref={titleRef} tabIndex={-1}> - Connect {meta.name} - </DialogTitle> - <DialogDescription>{meta.subtitle}</DialogDescription> - </div> - </div> - </DialogHeader> - <div className="min-h-0 flex-1 space-y-5 overflow-y-auto px-6 py-5"> - {provider === "azure" ? ( - <AzureConnectForm {...formProps} /> - ) : provider === "bedrock" ? ( - <BedrockConnectForm {...formProps} /> - ) : provider === "vertex_ai" ? ( - <VertexConnectForm {...formProps} /> - ) : ( - <DefaultConnectForm {...formProps} /> - )} - - <Separator className="bg-muted-foreground/20" /> - - <ModelsSelectionPanel - models={previewModels} - description={modelDescription} - isRefreshing={isPreviewingModels} - refreshLabel={`Refresh ${meta.name} models`} - onRefresh={canRefreshModels ? () => onPreviewModels?.(currentDraft) : undefined} - onAddManual={onAddPreviewModel} - onToggleModel={onTogglePreviewModel} - onBulkToggle={onBulkTogglePreviewModels} - /> - </div> - <ConnectFormFooter - onCancel={() => onOpenChange(false)} - onSubmit={() => onSubmit(currentDraft)} - canSubmit={canConnect} - isPending={isPending} - /> - </DialogContent> - </Dialog> - ); -} diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx deleted file mode 100644 index 8b8a877b9..000000000 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ /dev/null @@ -1,137 +0,0 @@ -import { getProviderIcon } from "@/lib/provider-icons"; - -export const PROVIDER_ORDER = [ - "openai", - "anthropic", - "vertex_ai", - "bedrock", - "azure", - "openrouter", - "ollama_chat", - "lm_studio", - "openai_compatible", -]; - -export const PROVIDER_DISPLAY: Record< - string, - { name: string; subtitle: string; iconKey?: string; defaultBaseUrl?: string } -> = { - anthropic: { - name: "Claude", - subtitle: "Anthropic", - iconKey: "claude", - defaultBaseUrl: "https://api.anthropic.com/v1", - }, - azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure" }, - bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, - lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "lm_studio" }, - ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, - openai: { - name: "GPT", - subtitle: "OpenAI", - iconKey: "openai", - defaultBaseUrl: "https://api.openai.com/v1", - }, - openai_compatible: { - name: "OpenAI-Compatible", - subtitle: "OpenAI-compatible endpoint", - iconKey: "custom", - }, - openrouter: { - name: "OpenRouter", - subtitle: "OpenRouter", - iconKey: "openrouter", - defaultBaseUrl: "https://openrouter.ai/api/v1", - }, - vertex_ai: { name: "Gemini", subtitle: "Google Cloud Vertex AI", iconKey: "vertex_ai" }, -}; - -export function providerDisplay(provider: string) { - const fallback = provider - .split("_") - .filter(Boolean) - .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) - .join(" "); - - return ( - PROVIDER_DISPLAY[provider] ?? { - name: fallback || provider, - subtitle: provider, - iconKey: provider, - } - ); -} - -export function providerIcon(provider: string, className = "size-4") { - return getProviderIcon(providerDisplay(provider).iconKey ?? provider, { className }); -} - -export function providerDefaultBaseUrl(provider: string, registryDefault?: string | null) { - return registryDefault ?? PROVIDER_DISPLAY[provider]?.defaultBaseUrl ?? ""; -} - -export const AWS_REGION_OPTIONS = [ - "us-east-1", - "us-east-2", - "us-west-2", - "us-gov-east-1", - "us-gov-west-1", - "ap-northeast-1", - "ap-south-1", - "ap-southeast-1", - "ap-southeast-2", - "ap-east-1", - "ca-central-1", - "eu-central-1", - "eu-west-2", -]; - -export const VERTEX_DEFAULT_LOCATION = "global"; - -export const BEDROCK_AUTH_IAM = "iam"; -export const BEDROCK_AUTH_ACCESS_KEY = "access_key"; -export const BEDROCK_AUTH_LONG_TERM_API_KEY = "long_term_api_key"; - -export const VERTEX_AUTH_SERVICE_ACCOUNT = "service_account_json"; -export const VERTEX_AUTH_WORKLOAD_IDENTITY = "workload_identity"; - -// Mirrors Onyx's Azure "Target URI" parser: the user pastes the full endpoint -// (e.g. https://res.cognitiveservices.azure.com/openai/deployments/<dep>/chat/completions?api-version=<ver>) -// which we split into api base (origin), api version, and deployment name. -export function parseAzureTargetUri(rawUri: string) { - try { - const url = new URL(rawUri); - const deploymentMatch = url.pathname.match(/\/openai\/deployments\/([^/]+)/i); - return { - origin: url.origin, - apiVersion: url.searchParams.get("api-version")?.trim() ?? "", - deploymentName: deploymentMatch?.[1] ? deploymentMatch[1].toLowerCase() : "", - isResponsesPath: /\/openai\/responses/i.test(url.pathname), - }; - } catch { - return null; - } -} - -export function isValidAzureTargetUri(rawUri: string) { - const parsed = parseAzureTargetUri(rawUri); - if (!parsed) return false; - return Boolean(parsed.apiVersion) && (Boolean(parsed.deploymentName) || parsed.isResponsesPath); -} - -/** Connection payload produced by a provider connect form. */ -export interface ConnectionDraft { - base_url: string | null; - api_key: string | null; - extra: Record<string, unknown>; - /** Model id to seed after creation (providers without discovery, e.g. Azure). */ - seedModelId?: string; -} - -/** Props shared by every provider-specific connect form. */ -export interface ProviderConnectFormProps { - provider: string; - defaultBaseUrl: string; - baseUrlRequired: boolean; - onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void; -} diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx deleted file mode 100644 index 1027742bc..000000000 --- a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx +++ /dev/null @@ -1,118 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { - type ProviderConnectFormProps, - VERTEX_AUTH_SERVICE_ACCOUNT, - VERTEX_AUTH_WORKLOAD_IDENTITY, - VERTEX_DEFAULT_LOCATION, -} from "./provider-metadata"; - -/** - * Google Vertex AI (Gemini) connect form. Service-account auth uploads a - * credentials JSON file (read into a string); workload identity collects a - * project id. Credentials ride along in `extra.litellm_params`. - */ -export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); - const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); - const [credentials, setCredentials] = useState(""); - const [project, setProject] = useState(""); - - const canSubmit = - authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? Boolean(credentials) : Boolean(project); - - async function handleCredentialsFile(file: File | undefined) { - if (!file) return; - setCredentials(await file.text()); - } - - useEffect(() => { - const params: Record<string, string> = {}; - if (location) params.vertex_location = location; - if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { - if (credentials) params.vertex_credentials = credentials; - } else if (project) { - params.vertex_project = project; - } - onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); - }, [authMethod, canSubmit, credentials, location, onDraftChange, project]); - - return ( - <div className="flex flex-col gap-4"> - <div className="flex flex-col gap-2"> - <Label>Authentication Method</Label> - <Select value={authMethod} onValueChange={setAuthMethod}> - <SelectTrigger> - <SelectValue /> - </SelectTrigger> - <SelectContent> - <SelectItem value={VERTEX_AUTH_SERVICE_ACCOUNT}>Service Account JSON</SelectItem> - <SelectItem value={VERTEX_AUTH_WORKLOAD_IDENTITY}>Workload Identity (GKE)</SelectItem> - </SelectContent> - </Select> - </div> - <div className="flex flex-col gap-2"> - <Label>Google Cloud Region Name</Label> - <Input - value={location} - onChange={(event) => setLocation(event.target.value)} - placeholder={VERTEX_DEFAULT_LOCATION} - /> - <p className="text-xs text-muted-foreground"> - Region where your Google Vertex AI models are hosted. - </p> - </div> - {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( - <div className="flex flex-col gap-2"> - <Label>Service Account JSON</Label> - <Input - id="vertex-service-account-json" - type="file" - accept="application/json,.json" - className="sr-only" - onChange={(event) => handleCredentialsFile(event.target.files?.[0])} - /> - <Label - htmlFor="vertex-service-account-json" - className="flex min-h-28 cursor-pointer flex-col items-center justify-center gap-2 rounded-lg border-2 border-dashed border-muted-foreground/40 bg-muted/20 px-4 py-6 text-center transition-colors hover:border-muted-foreground/70 hover:bg-muted/40" - > - <span className="text-sm font-medium"> - {credentials ? "Service account JSON selected" : "Upload service account JSON"} - </span> - <span className="text-xs text-muted-foreground"> - Choose a .json file from Google Cloud - </span> - </Label> - <p className="text-xs text-muted-foreground"> - {credentials - ? "Credentials file loaded." - : "Attach your service account key JSON from Google Cloud."} - </p> - </div> - ) : ( - <div className="flex flex-col gap-2"> - <Label>GCP Project ID</Label> - <Input - value={project} - onChange={(event) => setProject(event.target.value)} - placeholder="my-vertex-project" - /> - <p className="text-xs text-muted-foreground"> - The GCP project where Vertex AI is enabled. - </p> - </div> - )} - <p className="text-xs text-muted-foreground"> - Add Vertex AI model IDs from the provider's settings after connecting. - </p> - </div> - ); -} diff --git a/surfsense_web/components/settings/earn-credits-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx similarity index 80% rename from surfsense_web/components/settings/earn-credits-content.tsx rename to surfsense_web/components/settings/more-pages-content.tsx index 731ea7726..e1b05f4d2 100644 --- a/surfsense_web/components/settings/earn-credits-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -22,14 +22,7 @@ import { } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; -// Compact dollar label for a task's reward (e.g. "+$0.03"). -const formatRewardUsd = (micros: number) => { - const dollars = micros / 1_000_000; - if (dollars >= 1) return `+$${dollars.toFixed(2)}`; - return `+$${dollars.toFixed(2)}`; -}; - -export function EarnCreditsContent() { +export function MorePagesContent() { const params = useParams(); const queryClient = useQueryClient(); const searchSpaceId = params?.search_space_id ?? ""; @@ -42,11 +35,11 @@ export function EarnCreditsContent() { queryKey: ["incentive-tasks"], queryFn: () => incentiveTasksApiService.getTasks(), }); - const { data: creditStatus } = useQuery({ - queryKey: ["credit-status"], - queryFn: () => stripeApiService.getCreditStatus(), + const { data: stripeStatus } = useQuery({ + queryKey: ["stripe-status"], + queryFn: () => stripeApiService.getStatus(), }); - const creditBuyingEnabled = creditStatus?.credit_buying_enabled ?? true; + const pageBuyingEnabled = stripeStatus?.page_buying_enabled ?? true; const completeMutation = useMutation({ mutationFn: incentiveTasksApiService.completeTask, @@ -55,7 +48,7 @@ export function EarnCreditsContent() { toast.success(response.message); const task = data?.tasks.find((t) => t.task_type === taskType); if (task) { - trackIncentiveTaskCompleted(taskType, task.credit_micros_reward); + trackIncentiveTaskCompleted(taskType, task.pages_reward); } queryClient.invalidateQueries({ queryKey: ["incentive-tasks"] }); queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); @@ -76,12 +69,12 @@ export function EarnCreditsContent() { return ( <div className="w-full space-y-5"> <div className="text-center"> - <h2 className="text-xl font-bold tracking-tight">Earn Credits</h2> - <p className="mt-1 text-sm text-muted-foreground">Earn bonus credits by completing tasks</p> + <h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2> + <p className="mt-1 text-sm text-muted-foreground">Earn bonus pages by completing tasks</p> </div> <div className="space-y-2"> - <h3 className="text-sm font-semibold">Earn Bonus Credits</h3> + <h3 className="text-sm font-semibold">Earn Bonus Pages</h3> {isLoading ? ( <div className="space-y-1.5"> {["github", "reddit", "discord"].map((task) => ( @@ -104,16 +97,14 @@ export function EarnCreditsContent() { <CardContent className="flex items-center gap-3 p-3"> <div className={cn( - "flex h-9 min-w-9 shrink-0 items-center justify-center rounded-full px-2", + "flex h-8 w-8 shrink-0 items-center justify-center rounded-full", task.completed ? "bg-primary text-primary-foreground" : "bg-muted" )} > {task.completed ? ( <Check className="h-3.5 w-3.5" /> ) : ( - <span className="text-[11px] font-semibold tabular-nums"> - {formatRewardUsd(task.credit_micros_reward)} - </span> + <span className="text-xs font-semibold">+{task.pages_reward}</span> )} </div> <p @@ -160,13 +151,15 @@ export function EarnCreditsContent() { <div className="text-center"> <p className="text-sm text-muted-foreground">Need more?</p> - {creditBuyingEnabled ? ( + {pageBuyingEnabled ? ( <Button asChild variant="link" className="text-emerald-600 dark:text-emerald-400"> - <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy credits at $1 per $1</Link> + <Link href={`/dashboard/${searchSpaceId}/buy-pages`}> + Buy page packs at $1 per 1,000 + </Link> </Button> ) : ( <p className="text-xs text-muted-foreground"> - Credit purchases are temporarily unavailable. + Page purchases are temporarily unavailable. </p> )} </div> diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx new file mode 100644 index 000000000..31578b4f1 --- /dev/null +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -0,0 +1,486 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; +import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { VisionLLMConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface VisionModelManagerProps { + searchSpaceId: number; +} + +function getInitials(name: string): string { + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + + const { + mutateAsync: deleteConfig, + isPending: isDeleting, + error: deleteError, + } = useAtomValue(deleteVisionLLMConfigMutationAtom); + + const { + data: userConfigs, + isFetching: configsLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(visionLLMConfigsAtom); + const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue( + globalVisionLLMConfigsAtom + ); + + const { data: members } = useAtomValue(membersAtom); + const memberMap = useMemo(() => { + const map = new Map<string, { name: string; email?: string; avatarUrl?: string }>(); + if (members) { + for (const m of members) { + map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + const { data: access } = useAtomValue(myAccessAtom); + const canCreate = useMemo(() => { + if (!access) return false; + if (access.is_owner) return true; + return access.permissions?.includes("vision_configs:create") ?? false; + }, [access]); + const canDelete = useMemo(() => { + if (!access) return false; + if (access.is_owner) return true; + return access.permissions?.includes("vision_configs:delete") ?? false; + }, [access]); + const canUpdate = canCreate; + const isReadOnly = !canCreate && !canDelete; + + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState<VisionLLMConfig | null>(null); + const [configToDelete, setConfigToDelete] = useState<VisionLLMConfig | null>(null); + + const isLoading = configsLoading || globalLoading; + const errors = [deleteError, fetchError].filter(Boolean) as Error[]; + + const openEditDialog = (config: VisionLLMConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + + const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation + } + }; + + return ( + <div className="space-y-4 md:space-y-6"> + <div className="flex items-center justify-between"> + <Button + variant="secondary" + size="sm" + onClick={() => refreshConfigs()} + disabled={isLoading} + className="gap-2" + > + <RefreshCw className={cn("h-3.5 w-3.5", configsLoading && "animate-spin")} /> + Refresh + </Button> + {canCreate && ( + <Button + variant="outline" + onClick={openNewDialog} + className="gap-2 border-transparent bg-white text-[#1f1f1f] font-medium hover:bg-zinc-100 hover:text-[#1f1f1f] dark:border-transparent dark:bg-white dark:text-[#1f1f1f]" + > + Add Vision Model + </Button> + )} + </div> + + {errors.map((err) => ( + <div key={err?.message}> + <Alert variant="destructive" className="py-3"> + <AlertCircle className="h-3 w-3 md:h-4 md:w-4 shrink-0" /> + <AlertDescription className="text-xs md:text-sm">{err?.message}</AlertDescription> + </Alert> + </div> + ))} + + {access && !isLoading && isReadOnly && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You have <span className="font-medium">read-only</span> access to vision model + configurations. Contact a space owner to request additional permissions. + </p> + </AlertDescription> + </Alert> + </div> + )} + {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( + <div> + <Alert> + <Info /> + <AlertDescription> + <p> + You can{" "} + {[canCreate && "create and edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + vision model configurations + {!canDelete && ", but cannot delete them"}. + </p> + </AlertDescription> + </Alert> + </div> + )} + + {(isLoading || + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( + <Alert> + <Info /> + <AlertDescription> + {isLoading ? ( + <div className="flex min-h-[1.625em] items-center"> + <Skeleton className="h-4 w-60 bg-accent-foreground/15" /> + </div> + ) : ( + <p> + <span className="font-medium"> + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} + global vision{" "} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === + 1 + ? "model" + : "models"} + </span>{" "} + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} + </p> + )} + </AlertDescription> + </Alert> + )} + + {/* Global Vision Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( + <div className="space-y-3"> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + <Card + key={cfg.id} + className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full" + > + <CardContent className="p-4 flex flex-col gap-3 h-full"> + <div className="flex items-center gap-2 min-w-0"> + <div className="shrink-0"> + {getProviderIcon(cfg.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1 flex items-center gap-1.5"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {cfg.name} + </h4> + {isPremium ? ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0" + > + Premium + </Badge> + ) : ( + <Badge + variant="secondary" + className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0" + > + Free + </Badge> + )} + </div> + </div> + {cfg.description && ( + <p className="text-[11px] text-muted-foreground/70 line-clamp-2"> + {cfg.description} + </p> + )} + <div className="mt-auto space-y-2"> + <Separator className="bg-accent" /> + <div className="flex items-center"> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {cfg.model_name} + </span> + </div> + </div> + </CardContent> + </Card> + ); + })} + </div> + </div> + )} + + {isLoading && ( + <div className="space-y-4 md:space-y-6"> + <div className="space-y-4"> + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + <Card key={key} className="border-accent bg-accent/20"> + <CardContent className="p-4 flex flex-col gap-3 min-h-32"> + <Skeleton className="h-4 w-32 md:w-40 bg-accent" /> + <Skeleton className="h-3 w-full bg-accent" /> + <Skeleton className="h-3 w-24 md:w-28 bg-accent mt-auto" /> + </CardContent> + </Card> + ))} + </div> + </div> + </div> + )} + + {!isLoading && ( + <div className="space-y-4 md:space-y-6"> + {(userConfigs?.length ?? 0) === 0 ? ( + <Card className="border-0 bg-transparent shadow-none"> + <CardContent className="flex flex-col items-center justify-center py-10 md:py-16 text-center"> + <h3 className="text-sm md:text-base font-semibold mb-2">No Vision Models Yet</h3> + <p className="text-[11px] md:text-xs text-muted-foreground max-w-sm mb-4"> + {canCreate + ? "Add your own vision-capable model (GPT-4o, Claude, Gemini, etc.)" + : "No vision models have been added to this space yet. Contact a space owner to add one."} + </p> + </CardContent> + </Card> + ) : ( + <div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3"> + {userConfigs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( + <div key={config.id}> + <Card className="group relative overflow-hidden transition-all duration-200 border-accent bg-accent/20 hover:shadow-md h-full"> + <CardContent className="p-4 flex flex-col gap-3 h-full"> + {/* Header: Icon + Name + Actions */} + <div className="flex items-center justify-between gap-2"> + <div className="flex items-center gap-2.5 min-w-0 flex-1"> + <div className="shrink-0"> + {getProviderIcon(config.provider, { className: "size-4" })} + </div> + <div className="min-w-0 flex-1"> + <h4 className="text-sm font-semibold tracking-tight truncate"> + {config.name} + </h4> + {config.description && ( + <p className="text-[11px] text-muted-foreground/70 truncate mt-0.5"> + {config.description} + </p> + )} + </div> + </div> + {(canUpdate || canDelete) && ( + <div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150"> + {canUpdate && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => openEditDialog(config)} + className="h-6 w-6 text-muted-foreground hover:text-accent-foreground" + > + <Pencil className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Edit</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + {canDelete && ( + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <Button + variant="ghost" + size="icon" + onClick={() => setConfigToDelete(config)} + className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive" + > + <Trash2 className="h-3 w-3" /> + </Button> + </TooltipTrigger> + <TooltipContent>Delete</TooltipContent> + </Tooltip> + </TooltipProvider> + )} + </div> + )} + </div> + + {/* Footer: Date + Creator */} + <div className="mt-auto space-y-2"> + <Separator className="bg-accent" /> + <div className="flex items-center"> + <span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap"> + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + </span> + {member && ( + <> + <Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" /> + <TooltipProvider> + <Tooltip open={isDesktop ? undefined : false}> + <TooltipTrigger asChild> + <div className="min-w-0 flex items-center gap-1.5 cursor-default"> + <Avatar className="size-4.5 shrink-0"> + {member.avatarUrl && ( + <AvatarImage src={member.avatarUrl} alt={member.name} /> + )} + <AvatarFallback className="text-[9px]"> + {getInitials(member.name)} + </AvatarFallback> + </Avatar> + <span className="text-[11px] text-muted-foreground/60 truncate"> + {member.name} + </span> + </div> + </TooltipTrigger> + <TooltipContent side="bottom"> + {member.email || member.name} + </TooltipContent> + </Tooltip> + </TooltipProvider> + </> + )} + </div> + </div> + </CardContent> + </Card> + </div> + ); + })} + </div> + )} + </div> + )} + + <VisionConfigDialog + open={isDialogOpen} + onOpenChange={(open) => { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + <AlertDialog + open={!!configToDelete} + onOpenChange={(open) => !open && setConfigToDelete(null)} + > + <AlertDialogContent className="select-none"> + <AlertDialogHeader> + <AlertDialogTitle>Delete Vision Model</AlertDialogTitle> + <AlertDialogDescription> + Are you sure you want to delete{" "} + <span className="font-semibold text-foreground">{configToDelete?.name}</span>? + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={handleDelete} + disabled={isDeleting} + className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + <span className={isDeleting ? "opacity-0" : ""}>Delete</span> + {isDeleting && <Spinner size="sm" className="absolute" />} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </div> + ); +} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx new file mode 100644 index 000000000..36d16081a --- /dev/null +++ b/surfsense_web/components/shared/image-config-dialog.tsx @@ -0,0 +1,456 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { + createImageGenConfigMutationAtom, + updateImageGenConfigMutationAtom, +} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; +import type { + GlobalImageGenConfig, + ImageGenerationConfig, + ImageGenProvider, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface ImageConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: ImageGenerationConfig | GlobalImageGenConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; + defaultProvider?: string; +} + +const INITIAL_FORM = { + name: "", + description: "", + provider: "", + model_name: "", + api_key: "", + api_base: "", + api_version: "", +}; + +export function ImageConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, + defaultProvider, +}: ImageConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [formData, setFormData] = useState(INITIAL_FORM); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef<HTMLDivElement>(null); + + useEffect(() => { + if (open) { + if (mode === "edit" && config && !isGlobal) { + setFormData({ + name: config.name || "", + description: config.description || "", + provider: config.provider || "", + model_name: config.model_name || "", + api_key: (config as ImageGenerationConfig).api_key || "", + api_base: config.api_base || "", + api_version: config.api_version || "", + }); + } else if (mode === "create") { + setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); + } + setScrollPos("top"); + } + }, [open, mode, config, isGlobal, defaultProvider]); + + const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const handleScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + + const suggestedModels = useMemo(() => { + if (!formData.provider) return []; + return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); + }, [formData.provider]); + + const getTitle = () => { + if (mode === "create") return "Add Image Model"; + if (isGlobal) return "View Global Image Model"; + return "Edit Image Model"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new image generation provider"; + if (isGlobal) return "Read-only global configuration"; + return "Update your image model settings"; + }; + + const handleSubmit = useCallback(async () => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + name: formData.name, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + description: formData.description || undefined, + search_space_id: searchSpaceId, + }); + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: result.id }, + }); + } + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: formData.name, + description: formData.description || undefined, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save image config:", error); + toast.error("Failed to save image model"); + } finally { + setIsSubmitting(false); + } + }, [ + mode, + isGlobal, + config, + formData, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ]); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: config.id }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set image model:", error); + toast.error("Failed to set image model"); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; + const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); + + return ( + <Dialog open={open} onOpenChange={onOpenChange}> + <DialogContent + className="max-w-lg h-[85vh] flex flex-col p-0 gap-0 overflow-hidden" + onOpenAutoFocus={(e) => e.preventDefault()} + > + <DialogTitle className="sr-only">{getTitle()}</DialogTitle> + + {/* Header */} + <div className="flex items-start justify-between px-6 pt-6 pb-4 pr-14"> + <div className="space-y-1"> + <div className="flex items-center gap-2"> + <h2 className="text-lg font-semibold tracking-tight">{getTitle()}</h2> + {isGlobal && mode !== "create" && ( + <Badge variant="secondary" className="text-[10px]"> + Global + </Badge> + )} + </div> + <p className="text-sm text-muted-foreground">{getSubtitle()}</p> + {config && mode !== "create" && ( + <p className="text-xs font-mono text-muted-foreground/70">{config.model_name}</p> + )} + </div> + </div> + + {/* Scrollable content */} + <div + ref={scrollRef} + onScroll={handleScroll} + className="flex-1 overflow-y-auto px-6 py-5" + style={{ + maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + }} + > + {isGlobal && config && ( + <> + <Alert className="mb-5 border-amber-500/30 bg-amber-500/5"> + <AlertCircle className="size-4 text-amber-500" /> + <AlertDescription className="text-sm text-amber-700 dark:text-amber-400"> + Global configurations are read-only. To customize, create a new model. + </AlertDescription> + </Alert> + <div className="space-y-4"> + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Name + </div> + <p className="text-sm font-medium">{config.name}</p> + </div> + {config.description && ( + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Description + </div> + <p className="text-sm text-muted-foreground">{config.description}</p> + </div> + )} + </div> + <Separator /> + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Provider + </div> + <p className="text-sm font-medium">{config.provider}</p> + </div> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Model + </div> + <p className="text-sm font-medium font-mono">{config.model_name}</p> + </div> + </div> + </div> + </> + )} + + {(mode === "create" || (mode === "edit" && !isGlobal)) && ( + <div className="space-y-4"> + <div className="space-y-2"> + <Label className="text-sm font-medium">Name *</Label> + <Input + placeholder="e.g., My DALL-E 3" + value={formData.name} + onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))} + /> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Description</Label> + <Input + placeholder="Optional description" + value={formData.description} + onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))} + /> + </div> + + <Separator className="bg-popover-border" /> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Provider *</Label> + <Select + value={formData.provider} + onValueChange={(val) => + setFormData((p) => ({ ...p, provider: val, model_name: "" })) + } + > + <SelectTrigger> + <SelectValue placeholder="Select a provider" /> + </SelectTrigger> + <SelectContent> + {IMAGE_GEN_PROVIDERS.map((p) => ( + <SelectItem key={p.value} value={p.value} description={p.example}> + {p.label} + </SelectItem> + ))} + </SelectContent> + </Select> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Model Name *</Label> + {suggestedModels.length > 0 ? ( + <Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}> + <PopoverTrigger asChild> + <Button + variant="outline" + role="combobox" + className="w-full justify-between font-normal bg-transparent" + > + {formData.model_name || "Select or type a model..."} + <ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" /> + </Button> + </PopoverTrigger> + <PopoverContent className="w-full p-0" align="start"> + <Command className="bg-transparent"> + <CommandInput + placeholder="Search or type model..." + value={formData.model_name} + onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))} + /> + <CommandList> + <CommandEmpty> + <span className="text-xs text-muted-foreground"> + Type a custom model name + </span> + </CommandEmpty> + <CommandGroup> + {suggestedModels.map((m) => ( + <CommandItem + key={m.value} + value={m.value} + onSelect={() => { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + <Check + className={cn( + "mr-2 h-4 w-4", + formData.model_name === m.value ? "opacity-100" : "opacity-0" + )} + /> + <span className="font-mono text-sm">{m.value}</span> + <span className="ml-2 text-xs text-muted-foreground"> + {m.label} + </span> + </CommandItem> + ))} + </CommandGroup> + </CommandList> + </Command> + </PopoverContent> + </Popover> + ) : ( + <Input + placeholder="e.g., dall-e-3" + value={formData.model_name} + onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">API Key *</Label> + <Input + type="password" + placeholder="sk-..." + value={formData.api_key} + onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">API Base URL</Label> + <Input + placeholder={selectedProvider?.apiBase || "Optional"} + value={formData.api_base} + onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> + </div> + + {formData.provider === "AZURE_OPENAI" && ( + <div className="space-y-2"> + <Label className="text-sm font-medium">API Version (Azure)</Label> + <Input + placeholder="2024-02-15-preview" + value={formData.api_version} + onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> + </div> + )} + </div> + )} + </div> + + {/* Fixed footer */} + <div className="shrink-0 px-6 py-4 flex items-center justify-end gap-3"> + <Button + type="button" + variant="secondary" + onClick={() => onOpenChange(false)} + disabled={isSubmitting} + className="text-sm h-9" + > + Cancel + </Button> + {mode === "create" || (mode === "edit" && !isGlobal) ? ( + <Button + onClick={handleSubmit} + disabled={isSubmitting || !isFormValid} + className="relative text-sm h-9 min-w-[120px]" + > + <span className={isSubmitting ? "opacity-0" : ""}> + {mode === "edit" ? "Save Changes" : "Add Model"} + </span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : isGlobal && config ? ( + <Button + className="relative text-sm h-9" + onClick={handleUseGlobalConfig} + disabled={isSubmitting} + > + <span className={isSubmitting ? "opacity-0" : ""}>Use This Model</span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : null} + </div> + </DialogContent> + </Dialog> + ); +} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx new file mode 100644 index 000000000..06de4129b --- /dev/null +++ b/surfsense_web/components/shared/llm-config-form.tsx @@ -0,0 +1,527 @@ +"use client"; + +import { zodResolver } from "@hookform/resolvers/zod"; +import { useAtomValue } from "jotai"; +import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { type Resolver, useForm } from "react-hook-form"; +import { z } from "zod"; +import { + defaultSystemInstructionsAtom, + modelListAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Switch } from "@/components/ui/switch"; +import { Textarea } from "@/components/ui/textarea"; +import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; +import type { CreateNewLLMConfigRequest } from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; +import InferenceParamsEditor from "../inference-params-editor"; + +// Form schema with zod +const formSchema = z.object({ + name: z.string().min(1, "Name is required").max(100), + description: z.string().max(500).optional().nullable(), + provider: z.string().min(1, "Provider is required"), + custom_provider: z.string().max(100).optional().nullable(), + model_name: z.string().min(1, "Model name is required").max(100), + api_key: z.string().min(1, "API key is required"), + api_base: z.string().max(500).optional().nullable(), + litellm_params: z.record(z.string(), z.any()).optional().nullable(), + system_instructions: z.string().default(""), + use_default_system_instructions: z.boolean().default(true), + citations_enabled: z.boolean().default(true), + search_space_id: z.number(), +}); + +type FormValues = z.infer<typeof formSchema>; + +export type LLMConfigFormData = CreateNewLLMConfigRequest; + +interface LLMConfigFormProps { + initialData?: Partial<LLMConfigFormData>; + searchSpaceId: number; + onSubmit: (data: LLMConfigFormData) => Promise<void>; + mode?: "create" | "edit"; + showAdvanced?: boolean; + formId?: string; +} + +export function LLMConfigForm({ + initialData, + searchSpaceId, + onSubmit, + mode = "create", + showAdvanced = true, + formId, +}: LLMConfigFormProps) { + const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( + defaultSystemInstructionsAtom + ); + const { data: dynamicModels } = useAtomValue(modelListAtom); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const [advancedOpen, setAdvancedOpen] = useState(false); + const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); + + const form = useForm<FormValues>({ + resolver: zodResolver(formSchema) as Resolver<FormValues>, + defaultValues: { + name: initialData?.name ?? "", + description: initialData?.description ?? "", + provider: initialData?.provider ?? "", + custom_provider: initialData?.custom_provider ?? "", + model_name: initialData?.model_name ?? "", + api_key: initialData?.api_key ?? "", + api_base: initialData?.api_base ?? "", + litellm_params: initialData?.litellm_params ?? {}, + system_instructions: initialData?.system_instructions ?? "", + use_default_system_instructions: initialData?.use_default_system_instructions ?? true, + citations_enabled: initialData?.citations_enabled ?? true, + search_space_id: searchSpaceId, + }, + }); + + // Load default instructions when available (only for new configs) + useEffect(() => { + if ( + mode === "create" && + defaultInstructionsLoaded && + defaultInstructions?.default_system_instructions && + !form.getValues("system_instructions") + ) { + form.setValue("system_instructions", defaultInstructions.default_system_instructions); + } + }, [defaultInstructionsLoaded, defaultInstructions, mode, form]); + + const watchProvider = form.watch("provider"); + const selectedProvider = LLM_PROVIDERS.find((p) => p.value === watchProvider); + const availableModels = useMemo( + () => (dynamicModels ?? []).filter((m) => m.provider === watchProvider), + [dynamicModels, watchProvider] + ); + + const handleProviderChange = (value: string) => { + form.setValue("provider", value); + form.setValue("model_name", ""); + + // Auto-fill API base for certain providers + const provider = LLM_PROVIDERS.find((p) => p.value === value); + if (provider?.apiBase) { + form.setValue("api_base", provider.apiBase); + } + }; + + const handleFormSubmit = async (values: FormValues) => { + await onSubmit(values as LLMConfigFormData); + }; + + return ( + <Form {...form}> + <form id={formId} onSubmit={form.handleSubmit(handleFormSubmit)} className="space-y-6"> + {/* Model Configuration Section */} + <div className="space-y-4"> + <div className="text-xs sm:text-sm font-medium text-muted-foreground"> + Model Configuration + </div> + + {/* Name & Description */} + <div className="grid gap-4 sm:grid-cols-2"> + <FormField + control={form.control} + name="name" + render={({ field }) => ( + <FormItem> + <FormLabel className="text-xs sm:text-sm">Configuration Name</FormLabel> + <FormControl> + <Input placeholder="e.g., My GPT-4 Agent" {...field} /> + </FormControl> + <FormMessage /> + </FormItem> + )} + /> + + <FormField + control={form.control} + name="description" + render={({ field }) => ( + <FormItem> + <FormLabel className="text-muted-foreground text-xs sm:text-sm"> + Description + <Badge variant="outline" className="ml-2 text-[10px]"> + Optional + </Badge> + </FormLabel> + <FormControl> + <Input placeholder="Brief description" {...field} value={field.value ?? ""} /> + </FormControl> + <FormMessage /> + </FormItem> + )} + /> + </div> + + {/* Provider Selection */} + <FormField + control={form.control} + name="provider" + render={({ field }) => ( + <FormItem> + <FormLabel className="text-xs sm:text-sm">LLM Provider</FormLabel> + <Select value={field.value} onValueChange={handleProviderChange}> + <FormControl> + <SelectTrigger> + <SelectValue placeholder="Select a provider" /> + </SelectTrigger> + </FormControl> + <SelectContent className="max-h-[300px]"> + {LLM_PROVIDERS.map((provider) => ( + <SelectItem + key={provider.value} + value={provider.value} + description={provider.description} + > + {provider.label} + </SelectItem> + ))} + </SelectContent> + </Select> + <FormMessage /> + </FormItem> + )} + /> + + {/* Custom Provider (conditional) */} + {watchProvider === "CUSTOM" && ( + <FormField + control={form.control} + name="custom_provider" + render={({ field }) => ( + <FormItem> + <FormLabel className="text-xs sm:text-sm">Custom Provider Name</FormLabel> + <FormControl> + <Input placeholder="my-custom-provider" {...field} value={field.value ?? ""} /> + </FormControl> + <FormMessage /> + </FormItem> + )} + /> + )} + + {/* Model Name with Combobox */} + <FormField + control={form.control} + name="model_name" + render={({ field }) => ( + <FormItem className="flex flex-col"> + <FormLabel className="text-xs sm:text-sm">Model Name</FormLabel> + <Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}> + <PopoverTrigger asChild> + <FormControl> + <Button + variant="outline" + role="combobox" + aria-expanded={modelComboboxOpen} + className={cn( + "w-full justify-between border-popover-border bg-transparent font-normal", + !field.value && "text-muted-foreground" + )} + > + {field.value || "Select a model"} + <ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" /> + </Button> + </FormControl> + </PopoverTrigger> + <PopoverContent className="w-full p-0" align="start"> + <Command shouldFilter={false} className="bg-transparent"> + <CommandInput + placeholder={selectedProvider?.example || "Search model name"} + value={field.value} + onValueChange={field.onChange} + /> + <CommandList className="max-h-[300px]"> + <CommandEmpty> + <div className="py-3 text-center text-sm text-muted-foreground"> + {field.value ? `Using: "${field.value}"` : "Type your model name"} + </div> + </CommandEmpty> + {availableModels.length > 0 && ( + <CommandGroup heading="Suggested Models"> + {availableModels + .filter( + (model) => + !field.value || + model.value.toLowerCase().includes(field.value.toLowerCase()) || + model.label.toLowerCase().includes(field.value.toLowerCase()) + ) + .slice(0, 50) + .map((model) => ( + <CommandItem + key={model.value} + value={model.value} + onSelect={(value) => { + field.onChange(value); + setModelComboboxOpen(false); + }} + className="py-2" + > + <Check + className={cn( + "mr-2 h-4 w-4", + field.value === model.value ? "opacity-100" : "opacity-0" + )} + /> + <div> + <div className="font-medium">{model.label}</div> + {model.contextWindow && ( + <div className="text-xs text-muted-foreground"> + Context: {model.contextWindow} + </div> + )} + </div> + </CommandItem> + ))} + </CommandGroup> + )} + </CommandList> + </Command> + </PopoverContent> + </Popover> + {selectedProvider?.example && ( + <FormDescription className="text-[10px] sm:text-xs"> + Example: {selectedProvider.example} + </FormDescription> + )} + <FormMessage /> + </FormItem> + )} + /> + + {/* API Credentials */} + <div className="grid gap-4 sm:grid-cols-2"> + <FormField + control={form.control} + name="api_key" + render={({ field }) => ( + <FormItem> + <FormLabel className="text-xs sm:text-sm">API Key</FormLabel> + <FormControl> + <Input + type="password" + placeholder={watchProvider === "OLLAMA" ? "Any value" : "sk-..."} + {...field} + /> + </FormControl> + {watchProvider === "OLLAMA" && ( + <FormDescription className="text-[10px] sm:text-xs"> + Ollama doesn't require auth — enter any value + </FormDescription> + )} + <FormMessage /> + </FormItem> + )} + /> + + <FormField + control={form.control} + name="api_base" + render={({ field }) => ( + <FormItem> + <FormLabel className="flex items-center gap-2 text-xs sm:text-sm"> + API Base URL + {selectedProvider?.apiBase && ( + <Badge variant="secondary" className="text-[10px]"> + Auto-filled + </Badge> + )} + </FormLabel> + <FormControl> + <Input + placeholder={selectedProvider?.apiBase || "https://api.example.com/v1"} + {...field} + value={field.value ?? ""} + /> + </FormControl> + <FormMessage /> + </FormItem> + )} + /> + </div> + + {/* Ollama Quick Actions */} + {watchProvider === "OLLAMA" && ( + <div className="flex flex-wrap gap-2"> + <Button + type="button" + variant="outline" + size="sm" + className="h-7 text-xs" + onClick={() => form.setValue("api_base", "http://localhost:11434")} + > + localhost:11434 + </Button> + <Button + type="button" + variant="outline" + size="sm" + className="h-7 text-xs" + onClick={() => form.setValue("api_base", "http://host.docker.internal:11434")} + > + Docker + </Button> + </div> + )} + </div> + + {/* Advanced Parameters */} + {showAdvanced && ( + <> + <Separator className="bg-popover-border" /> + <Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}> + <CollapsibleTrigger asChild> + <Button + type="button" + variant="ghost" + className="h-auto w-full justify-between px-0 py-2 text-xs font-medium text-muted-foreground hover:bg-transparent hover:text-accent-foreground sm:text-sm" + > + <span>Advanced Parameters</span> + <ChevronDown + className={cn( + "h-4 w-4 transition-transform duration-200", + advancedOpen && "rotate-180" + )} + /> + </Button> + </CollapsibleTrigger> + <CollapsibleContent className="space-y-4 pt-2"> + <FormField + control={form.control} + name="litellm_params" + render={({ field }) => ( + <FormItem> + <FormControl> + <InferenceParamsEditor + params={field.value || {}} + setParams={field.onChange} + /> + </FormControl> + <FormMessage /> + </FormItem> + )} + /> + </CollapsibleContent> + </Collapsible> + </> + )} + + {/* System Instructions & Citations Section */} + <Separator className="bg-popover-border" /> + <Collapsible open={systemInstructionsOpen} onOpenChange={setSystemInstructionsOpen}> + <CollapsibleTrigger asChild> + <Button + type="button" + variant="ghost" + className="h-auto w-full justify-between px-0 py-2 text-xs font-medium text-muted-foreground hover:bg-transparent hover:text-accent-foreground sm:text-sm" + > + <span>System Instructions</span> + <ChevronDown + className={cn( + "h-4 w-4 transition-transform duration-200", + systemInstructionsOpen && "rotate-180" + )} + /> + </Button> + </CollapsibleTrigger> + <CollapsibleContent className="space-y-4 pt-2"> + {/* System Instructions */} + <FormField + control={form.control} + name="system_instructions" + render={({ field }) => ( + <FormItem> + <div className="flex items-center justify-between"> + <FormLabel className="text-xs sm:text-sm">Instructions for the AI</FormLabel> + {defaultInstructions && ( + <Button + type="button" + variant="ghost" + size="sm" + onClick={() => + field.onChange(defaultInstructions.default_system_instructions) + } + className="h-7 text-[10px] sm:text-xs text-muted-foreground hover:text-accent-foreground" + > + Reset to Default + </Button> + )} + </div> + <FormControl> + <Textarea + placeholder="Enter system instructions for the AI..." + rows={6} + className="font-mono text-[11px] sm:text-xs resize-none" + {...field} + /> + </FormControl> + <FormDescription className="text-[10px] sm:text-xs"> + Use {"{resolved_today}"} to include today's date dynamically + </FormDescription> + <FormMessage /> + </FormItem> + )} + /> + + {/* Citations Toggle */} + <FormField + control={form.control} + name="citations_enabled" + render={({ field }) => ( + <FormItem className="flex items-center justify-between rounded-lg border p-3 bg-muted/30"> + <div className="space-y-0.5"> + <FormLabel className="text-xs sm:text-sm font-medium"> + Enable Citations + </FormLabel> + <FormDescription className="text-[10px] sm:text-xs"> + Include [citation:id] references to source documents + </FormDescription> + </div> + <FormControl> + <Switch checked={field.value} onCheckedChange={field.onChange} /> + </FormControl> + </FormItem> + )} + /> + </CollapsibleContent> + </Collapsible> + </form> + </Form> + ); +} diff --git a/surfsense_web/components/shared/model-config-dialog.tsx b/surfsense_web/components/shared/model-config-dialog.tsx new file mode 100644 index 000000000..d4f57ff7d --- /dev/null +++ b/surfsense_web/components/shared/model-config-dialog.tsx @@ -0,0 +1,339 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle } from "lucide-react"; +import { useCallback, useRef, useState } from "react"; +import { toast } from "sonner"; +import { + createNewLLMConfigMutationAtom, + updateLLMPreferencesMutationAtom, + updateNewLLMConfigMutationAtom, +} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Spinner } from "@/components/ui/spinner"; +import type { + GlobalNewLLMConfig, + LiteLLMProvider, + NewLLMConfigPublic, +} from "@/contracts/types/new-llm-config.types"; + +interface ModelConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: NewLLMConfigPublic | GlobalNewLLMConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; + defaultProvider?: string; +} + +export function ModelConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, + defaultProvider, +}: ModelConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef<HTMLDivElement>(null); + + const handleScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + + const { mutateAsync: createConfig } = useAtomValue(createNewLLMConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateNewLLMConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const getTitle = () => { + if (mode === "create") return "Add New Configuration"; + if (isGlobal) return "View Global Configuration"; + return "Edit Configuration"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new LLM provider for this search space"; + if (isGlobal) return "Read-only global configuration"; + return "Update your configuration settings"; + }; + + const handleSubmit = useCallback( + async (data: LLMConfigFormData) => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + ...data, + search_space_id: searchSpaceId, + }); + + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: result.id, + }, + }); + } + + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: data.name, + description: data.description, + provider: data.provider, + custom_provider: data.custom_provider, + model_name: data.model_name, + api_key: data.api_key, + api_base: data.api_base, + litellm_params: data.litellm_params, + system_instructions: data.system_instructions, + use_default_system_instructions: data.use_default_system_instructions, + citations_enabled: data.citations_enabled, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save configuration:", error); + } finally { + setIsSubmitting(false); + } + }, + [ + mode, + isGlobal, + config, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ] + ); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: config.id, + }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set model:", error); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + return ( + <Dialog open={open} onOpenChange={onOpenChange}> + <DialogContent + className="max-w-lg h-[85vh] flex flex-col p-0 gap-0 overflow-hidden" + onOpenAutoFocus={(e) => e.preventDefault()} + > + <DialogTitle className="sr-only">{getTitle()}</DialogTitle> + + {/* Header */} + <div className="flex items-start justify-between px-6 pt-6 pb-4 pr-14"> + <div className="space-y-1"> + <div className="flex items-center gap-2"> + <h2 className="text-lg font-semibold tracking-tight">{getTitle()}</h2> + {isGlobal && mode !== "create" && ( + <Badge variant="secondary" className="text-[10px]"> + Global + </Badge> + )} + {!isGlobal && mode !== "create" && ( + <Badge variant="outline" className="text-[10px]"> + Custom + </Badge> + )} + </div> + <p className="text-sm text-muted-foreground">{getSubtitle()}</p> + {config && mode !== "create" && ( + <p className="text-xs font-mono text-muted-foreground/70">{config.model_name}</p> + )} + </div> + </div> + + {/* Scrollable content */} + <div + ref={scrollRef} + onScroll={handleScroll} + className="flex-1 overflow-y-auto px-6 py-5" + style={{ + maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + }} + > + {isGlobal && mode !== "create" && ( + <Alert className="mb-5 border-amber-500/30 bg-amber-500/5"> + <AlertCircle className="size-4 text-amber-500" /> + <AlertDescription className="text-sm text-amber-700 dark:text-amber-400"> + Global configurations are read-only. To customize settings, create a new + configuration based on this template. + </AlertDescription> + </Alert> + )} + + {mode === "create" ? ( + <LLMConfigForm + key={defaultProvider ?? "no-provider"} + searchSpaceId={searchSpaceId} + onSubmit={handleSubmit} + mode="create" + formId="model-config-form" + initialData={ + defaultProvider ? { provider: defaultProvider as LiteLLMProvider } : undefined + } + /> + ) : isGlobal && config ? ( + <div className="space-y-6"> + <div className="space-y-4"> + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Configuration Name + </div> + <p className="text-sm font-medium">{config.name}</p> + </div> + {config.description && ( + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Description + </div> + <p className="text-sm text-muted-foreground">{config.description}</p> + </div> + )} + </div> + + <div className="h-px bg-border/50" /> + + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Provider + </div> + <p className="text-sm font-medium">{config.provider}</p> + </div> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Model + </div> + <p className="text-sm font-medium font-mono">{config.model_name}</p> + </div> + </div> + + <div className="h-px bg-border/50" /> + + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-2"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Citations + </div> + <Badge + variant={config.citations_enabled ? "default" : "secondary"} + className="w-fit" + > + {config.citations_enabled ? "Enabled" : "Disabled"} + </Badge> + </div> + </div> + + {config.system_instructions && ( + <> + <div className="h-px bg-border/50" /> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + System Instructions + </div> + <div className="p-3 rounded-lg bg-muted/50 border border-border/50"> + <p className="text-xs font-mono text-muted-foreground whitespace-pre-wrap line-clamp-10"> + {config.system_instructions} + </p> + </div> + </div> + </> + )} + </div> + </div> + ) : config ? ( + <LLMConfigForm + searchSpaceId={searchSpaceId} + initialData={{ + name: config.name, + description: config.description, + provider: config.provider as LiteLLMProvider, + custom_provider: config.custom_provider, + model_name: config.model_name, + api_key: "api_key" in config ? (config.api_key as string) : "", + api_base: config.api_base, + litellm_params: config.litellm_params, + system_instructions: config.system_instructions, + use_default_system_instructions: config.use_default_system_instructions, + citations_enabled: config.citations_enabled, + search_space_id: searchSpaceId, + }} + onSubmit={handleSubmit} + mode="edit" + formId="model-config-form" + /> + ) : null} + </div> + + {/* Fixed footer */} + <div className="shrink-0 px-6 py-4 flex items-center justify-end gap-3"> + <Button + type="button" + variant="secondary" + onClick={() => onOpenChange(false)} + disabled={isSubmitting} + className="text-sm h-9" + > + Cancel + </Button> + {mode === "create" || (!isGlobal && config) ? ( + <Button + type="submit" + form="model-config-form" + disabled={isSubmitting} + className="relative text-sm h-9 min-w-[120px]" + > + <span className={isSubmitting ? "opacity-0" : ""}> + {mode === "edit" ? "Save Changes" : "Add Model"} + </span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : isGlobal && config ? ( + <Button + className="relative text-sm h-9" + onClick={handleUseGlobalConfig} + disabled={isSubmitting} + > + <span className={isSubmitting ? "opacity-0" : ""}>Use This Model</span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : null} + </div> + </DialogContent> + </Dialog> + ); +} diff --git a/surfsense_web/components/shared/vision-config-dialog.tsx b/surfsense_web/components/shared/vision-config-dialog.tsx new file mode 100644 index 000000000..2646f3842 --- /dev/null +++ b/surfsense_web/components/shared/vision-config-dialog.tsx @@ -0,0 +1,478 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + createVisionLLMConfigMutationAtom, + updateVisionLLMConfigMutationAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; +import { visionModelListAtom } from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { VISION_PROVIDERS } from "@/contracts/enums/vision-providers"; +import type { + GlobalVisionLLMConfig, + VisionLLMConfig, + VisionProvider, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface VisionConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: VisionLLMConfig | GlobalVisionLLMConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; + defaultProvider?: string; +} + +const INITIAL_FORM = { + name: "", + description: "", + provider: "", + model_name: "", + api_key: "", + api_base: "", + api_version: "", +}; + +export function VisionConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, + defaultProvider, +}: VisionConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [formData, setFormData] = useState(INITIAL_FORM); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef<HTMLDivElement>(null); + + useEffect(() => { + if (open) { + if (mode === "edit" && config && !isGlobal) { + setFormData({ + name: config.name || "", + description: config.description || "", + provider: config.provider || "", + model_name: config.model_name || "", + api_key: (config as VisionLLMConfig).api_key || "", + api_base: config.api_base || "", + api_version: (config as VisionLLMConfig).api_version || "", + }); + } else if (mode === "create") { + setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); + } + setScrollPos("top"); + } + }, [open, mode, config, isGlobal, defaultProvider]); + + const { mutateAsync: createConfig } = useAtomValue(createVisionLLMConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateVisionLLMConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const handleScroll = useCallback((e: React.UIEvent<HTMLDivElement>) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + + const getTitle = () => { + if (mode === "create") return "Add Vision Model"; + if (isGlobal) return "View Global Vision Model"; + return "Edit Vision Model"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new vision-capable LLM provider"; + if (isGlobal) return "Read-only global configuration"; + return "Update your vision model settings"; + }; + + const handleSubmit = useCallback(async () => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + name: formData.name, + provider: formData.provider as VisionProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + description: formData.description || undefined, + search_space_id: searchSpaceId, + }); + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { vision_llm_config_id: result.id }, + }); + } + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: formData.name, + description: formData.description || undefined, + provider: formData.provider as VisionProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save vision config:", error); + toast.error("Failed to save vision model"); + } finally { + setIsSubmitting(false); + } + }, [ + mode, + isGlobal, + config, + formData, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ]); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { vision_llm_config_id: config.id }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set vision model:", error); + toast.error("Failed to set vision model"); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + const { data: dynamicModels } = useAtomValue(visionModelListAtom); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + + const availableModels = useMemo( + () => (dynamicModels ?? []).filter((m) => m.provider === formData.provider), + [dynamicModels, formData.provider] + ); + + const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; + const selectedProvider = VISION_PROVIDERS.find((p) => p.value === formData.provider); + + return ( + <Dialog open={open} onOpenChange={onOpenChange}> + <DialogContent + className="max-w-lg h-[85vh] flex flex-col p-0 gap-0 overflow-hidden" + onOpenAutoFocus={(e) => e.preventDefault()} + > + <DialogTitle className="sr-only">{getTitle()}</DialogTitle> + + <div className="flex items-start justify-between px-6 pt-6 pb-4 pr-14"> + <div className="space-y-1"> + <div className="flex items-center gap-2"> + <h2 className="text-lg font-semibold tracking-tight">{getTitle()}</h2> + {isGlobal && mode !== "create" && ( + <Badge variant="secondary" className="text-[10px]"> + Global + </Badge> + )} + </div> + <p className="text-sm text-muted-foreground">{getSubtitle()}</p> + {config && mode !== "create" && ( + <p className="text-xs font-mono text-muted-foreground/70">{config.model_name}</p> + )} + </div> + </div> + + <div + ref={scrollRef} + onScroll={handleScroll} + className="flex-1 overflow-y-auto px-6 py-5" + style={{ + maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`, + }} + > + {isGlobal && config && ( + <> + <Alert className="mb-5 border-amber-500/30 bg-amber-500/5"> + <AlertCircle className="size-4 text-amber-500" /> + <AlertDescription className="text-sm text-amber-700 dark:text-amber-400"> + Global configurations are read-only. To customize, create a new model. + </AlertDescription> + </Alert> + <div className="space-y-4"> + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Name + </div> + <p className="text-sm font-medium">{config.name}</p> + </div> + {config.description && ( + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Description + </div> + <p className="text-sm text-muted-foreground">{config.description}</p> + </div> + )} + </div> + <Separator /> + <div className="grid gap-4 sm:grid-cols-2"> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Provider + </div> + <p className="text-sm font-medium">{config.provider}</p> + </div> + <div className="space-y-1.5"> + <div className="text-xs font-medium text-muted-foreground uppercase tracking-wider"> + Model + </div> + <p className="text-sm font-medium font-mono">{config.model_name}</p> + </div> + </div> + </div> + </> + )} + + {(mode === "create" || (mode === "edit" && !isGlobal)) && ( + <div className="space-y-4"> + <div className="space-y-2"> + <Label className="text-sm font-medium">Name *</Label> + <Input + placeholder="e.g., My GPT-4o Vision" + value={formData.name} + onChange={(e) => setFormData((p) => ({ ...p, name: e.target.value }))} + /> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Description</Label> + <Input + placeholder="Optional description" + value={formData.description} + onChange={(e) => setFormData((p) => ({ ...p, description: e.target.value }))} + /> + </div> + + <Separator className="bg-popover-border" /> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Provider *</Label> + <Select + value={formData.provider} + onValueChange={(val) => + setFormData((p) => ({ ...p, provider: val, model_name: "" })) + } + > + <SelectTrigger> + <SelectValue placeholder="Select a provider" /> + </SelectTrigger> + <SelectContent> + {VISION_PROVIDERS.map((p) => ( + <SelectItem key={p.value} value={p.value} description={p.example}> + {p.label} + </SelectItem> + ))} + </SelectContent> + </Select> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">Model Name *</Label> + <Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}> + <PopoverTrigger asChild> + <Button + variant="outline" + role="combobox" + aria-expanded={modelComboboxOpen} + className={cn( + "w-full justify-between font-normal bg-transparent", + !formData.model_name && "text-muted-foreground" + )} + > + {formData.model_name || "Select a model"} + <ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" /> + </Button> + </PopoverTrigger> + <PopoverContent className="w-full p-0" align="start"> + <Command shouldFilter={false} className="bg-transparent"> + <CommandInput + placeholder={selectedProvider?.example || "Search model name"} + value={formData.model_name} + onValueChange={(val) => setFormData((p) => ({ ...p, model_name: val }))} + /> + <CommandList className="max-h-[300px]"> + <CommandEmpty> + <div className="py-3 text-center text-sm text-muted-foreground"> + {formData.model_name + ? `Using: "${formData.model_name}"` + : "Type your model name"} + </div> + </CommandEmpty> + {availableModels.length > 0 && ( + <CommandGroup heading="Suggested Models"> + {availableModels + .filter( + (model) => + !formData.model_name || + model.value + .toLowerCase() + .includes(formData.model_name.toLowerCase()) || + model.label + .toLowerCase() + .includes(formData.model_name.toLowerCase()) + ) + .slice(0, 50) + .map((model) => ( + <CommandItem + key={model.value} + value={model.value} + onSelect={(value) => { + setFormData((p) => ({ + ...p, + model_name: value, + })); + setModelComboboxOpen(false); + }} + className="py-2" + > + <Check + className={cn( + "mr-2 h-4 w-4", + formData.model_name === model.value + ? "opacity-100" + : "opacity-0" + )} + /> + <div> + <div className="font-medium">{model.label}</div> + {model.contextWindow && ( + <div className="text-xs text-muted-foreground"> + Context: {model.contextWindow} + </div> + )} + </div> + </CommandItem> + ))} + </CommandGroup> + )} + </CommandList> + </Command> + </PopoverContent> + </Popover> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">API Key *</Label> + <Input + type="password" + placeholder="sk-..." + value={formData.api_key} + onChange={(e) => setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> + </div> + + <div className="space-y-2"> + <Label className="text-sm font-medium">API Base URL</Label> + <Input + placeholder={selectedProvider?.apiBase || "Optional"} + value={formData.api_base} + onChange={(e) => setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> + </div> + + {formData.provider === "AZURE_OPENAI" && ( + <div className="space-y-2"> + <Label className="text-sm font-medium">API Version (Azure)</Label> + <Input + placeholder="2024-02-15-preview" + value={formData.api_version} + onChange={(e) => setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> + </div> + )} + </div> + )} + </div> + + <div className="shrink-0 px-6 py-4 flex items-center justify-end gap-3"> + <Button + type="button" + variant="secondary" + onClick={() => onOpenChange(false)} + disabled={isSubmitting} + className="text-sm h-9" + > + Cancel + </Button> + {mode === "create" || (mode === "edit" && !isGlobal) ? ( + <Button + onClick={handleSubmit} + disabled={isSubmitting || !isFormValid} + className="relative text-sm h-9 min-w-[120px]" + > + <span className={isSubmitting ? "opacity-0" : ""}> + {mode === "edit" ? "Save Changes" : "Add Model"} + </span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : isGlobal && config ? ( + <Button + className="relative text-sm h-9" + onClick={handleUseGlobalConfig} + disabled={isSubmitting} + > + <span className={isSubmitting ? "opacity-0" : ""}>Use This Model</span> + {isSubmitting && <Spinner size="sm" className="absolute" />} + </Button> + ) : null} + </div> + </DialogContent> + </Dialog> + ); +} diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index 8ee203765..3f68f6d64 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -8,7 +8,6 @@ import { type ChangeEvent, useCallback, useEffect, useMemo, useRef, useState } f import { useDropzone } from "react-dropzone"; import { toast } from "sonner"; import { uploadDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Accordion, AccordionContent, @@ -137,7 +136,6 @@ export function DocumentUploadTab({ onAccordionStateChange, }: DocumentUploadTabProps) { const t = useTranslations("upload_documents"); - const { etlService } = useRuntimeConfig(); const [files, setFiles] = useState<FileWithId[]>([]); const [uploadProgress, setUploadProgress] = useState(0); const [accordionValue, setAccordionValue] = useState<string>(""); @@ -162,7 +160,7 @@ export function DocumentUploadTab({ const electronAPI = useElectronAPI(); const isElectron = !!electronAPI?.browseFiles; - const acceptedFileTypes = useMemo(() => getAcceptedFileTypes(etlService), [etlService]); + const acceptedFileTypes = useMemo(() => getAcceptedFileTypes(), []); const supportedExtensions = useMemo( () => getSupportedExtensions(acceptedFileTypes), [acceptedFileTypes] diff --git a/surfsense_web/components/sources/FolderWatchDialog.tsx b/surfsense_web/components/sources/FolderWatchDialog.tsx index 7a64f3835..8c5629276 100644 --- a/surfsense_web/components/sources/FolderWatchDialog.tsx +++ b/surfsense_web/components/sources/FolderWatchDialog.tsx @@ -3,7 +3,6 @@ import { X } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Dialog, @@ -49,7 +48,6 @@ export function FolderWatchDialog({ const [submitting, setSubmitting] = useState(false); const [progress, setProgress] = useState<FolderSyncProgress | null>(null); const abortRef = useRef<AbortController | null>(null); - const { etlService } = useRuntimeConfig(); useEffect(() => { if (open && initialFolder) { @@ -57,10 +55,7 @@ export function FolderWatchDialog({ } }, [open, initialFolder]); - const supportedExtensions = useMemo( - () => Array.from(getSupportedExtensionsSet(undefined, etlService)), - [etlService] - ); + const supportedExtensions = useMemo(() => Array.from(getSupportedExtensionsSet()), []); const handleSelectFolder = useCallback(async () => { const api = window.electronAPI; diff --git a/surfsense_web/components/tool-ui/audio.tsx b/surfsense_web/components/tool-ui/audio.tsx index cf78298b5..aeadae45b 100644 --- a/surfsense_web/components/tool-ui/audio.tsx +++ b/surfsense_web/components/tool-ui/audio.tsx @@ -201,7 +201,7 @@ export function Audio({ id, src, title, durationMs, className }: AudioProps) { <div className="mx-5 h-px bg-border/50" /> <div className="px-5 pt-3 pb-4 space-y-3"> - <div className="space-y-2"> + <div className="space-y-0.5"> <Slider value={[currentTime]} max={duration || 100} diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 8775b275b..24e9d66bd 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { AlarmClock, AlertCircle, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; +import { AlertCircle, CornerDownLeftIcon, ExternalLink, Pencil, Workflow } from "lucide-react"; import Link from "next/link"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -113,7 +113,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const eligibleModels = useAutomationEligibleModels(); const [modelSelection, setModelSelection] = useState<AutomationModelSelection>({ - chatModelId: 0, + agentLlmId: 0, imageConfigId: 0, visionConfigId: 0, }); @@ -121,7 +121,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { // default. No effect seeds async hook data into state. const resolvedModels = useMemo<AutomationModelSelection>( () => ({ - chatModelId: modelSelection.chatModelId || eligibleModels.llm.defaultId || 0, + agentLlmId: modelSelection.agentLlmId || eligibleModels.llm.defaultId || 0, imageConfigId: modelSelection.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: modelSelection.visionConfigId || eligibleModels.vision.defaultId || 0, }), @@ -133,7 +133,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { ] ); const modelsResolved = - resolvedModels.chatModelId !== 0 && + resolvedModels.agentLlmId !== 0 && resolvedModels.imageConfigId !== 0 && resolvedModels.visionConfigId !== 0; @@ -147,9 +147,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { definition: { ...baseDefinition, models: { - chat_model_id: resolvedModels.chatModelId, - image_gen_model_id: resolvedModels.imageConfigId, - vision_model_id: resolvedModels.visionConfigId, + agent_llm_id: resolvedModels.agentLlmId, + image_generation_config_id: resolvedModels.imageConfigId, + vision_llm_config_id: resolvedModels.visionConfigId, }, }, }; @@ -162,9 +162,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { trigger_type: (triggers[0] as { type?: string } | undefined)?.type ?? (triggers.length ? undefined : "none"), - chat_model_id: resolvedModels.chatModelId, - image_gen_model_id: resolvedModels.imageConfigId, - vision_model_id: resolvedModels.visionConfigId, + agent_llm_id: resolvedModels.agentLlmId, + image_generation_config_id: resolvedModels.imageConfigId, + vision_llm_config_id: resolvedModels.visionConfigId, }); onDecision({ type: "edit", @@ -211,7 +211,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 transition-[box-shadow] duration-300"> <div className="flex items-start justify-between gap-3 px-5 pt-5 pb-4 select-none"> <div className="flex items-start gap-3 min-w-0"> - <AlarmClock className="h-5 w-5 text-muted-foreground mt-0.5 shrink-0" aria-hidden /> + <Workflow className="h-5 w-5 text-muted-foreground mt-0.5 shrink-0" aria-hidden /> <div className="min-w-0"> <p className="text-sm font-semibold text-foreground"> {phase === "rejected" @@ -404,7 +404,7 @@ function SavedCard({ result }: { result: SavedResult }) { return ( <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> <div className="flex items-start gap-3 px-5 pt-5 pb-4"> - <AlarmClock className="h-5 w-5 text-muted-foreground mt-0.5 shrink-0" aria-hidden /> + <Workflow className="h-5 w-5 text-muted-foreground mt-0.5 shrink-0" aria-hidden /> <div className="min-w-0"> <p className="text-sm font-semibold text-foreground">Automation saved</p> <p className="text-xs text-muted-foreground mt-0.5">{result.name}</p> diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx new file mode 100644 index 000000000..2a62785e8 --- /dev/null +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -0,0 +1,468 @@ +"use client"; + +import type { ToolCallMessagePartProps } from "@assistant-ui/react"; +import { useParams, usePathname } from "next/navigation"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { z } from "zod"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { Audio } from "@/components/tool-ui/audio"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; +import { baseApiService } from "@/lib/apis/base-api.service"; +import { authenticatedFetch } from "@/lib/auth-utils"; +import { clearActivePodcastTaskId, setActivePodcastTaskId } from "@/lib/chat/podcast-state"; +import { BACKEND_URL } from "@/lib/env-config"; + +/** + * Zod schemas for runtime validation + */ +const GeneratePodcastArgsSchema = z.object({ + source_content: z.string(), + podcast_title: z.string().nullish(), + user_prompt: z.string().nullish(), +}); + +const GeneratePodcastResultSchema = z.object({ + // Support both old and new status values for backwards compatibility + status: z.enum([ + "pending", + "generating", + "ready", + "failed", + // Legacy values from old saved chats + "processing", + "already_generating", + "success", + "error", + ]), + podcast_id: z.number().nullish(), + task_id: z.string().nullish(), // Legacy field for old saved chats + title: z.string().nullish(), + transcript_entries: z.number().nullish(), + message: z.string().nullish(), + error: z.string().nullish(), +}); + +const PodcastStatusResponseSchema = z.object({ + status: z.enum(["pending", "generating", "ready", "failed"]), + id: z.number(), + title: z.string(), + transcript_entries: z.number().nullish(), + error: z.string().nullish(), +}); + +const PodcastTranscriptEntrySchema = z.object({ + speaker_id: z.number(), + dialog: z.string(), +}); + +const PodcastDetailsSchema = z.object({ + podcast_transcript: z.array(PodcastTranscriptEntrySchema).nullish(), +}); + +/** + * Types derived from Zod schemas + */ +type GeneratePodcastArgs = z.infer<typeof GeneratePodcastArgsSchema>; +type GeneratePodcastResult = z.infer<typeof GeneratePodcastResultSchema>; +type PodcastStatusResponse = z.infer<typeof PodcastStatusResponseSchema>; +type PodcastTranscriptEntry = z.infer<typeof PodcastTranscriptEntrySchema>; + +/** + * Parse and validate podcast status response + */ +function parsePodcastStatusResponse(data: unknown): PodcastStatusResponse | null { + const result = PodcastStatusResponseSchema.safeParse(data); + if (!result.success) { + console.warn("Invalid podcast status response:", result.error.issues); + return null; + } + return result.data; +} + +/** + * Parse and validate podcast details + */ +function parsePodcastDetails(data: unknown): { podcast_transcript?: PodcastTranscriptEntry[] } { + const result = PodcastDetailsSchema.safeParse(data); + if (!result.success) { + console.warn("Invalid podcast details:", result.error.issues); + return {}; + } + return { + podcast_transcript: result.data.podcast_transcript ?? undefined, + }; +} + +function PodcastGeneratingState({ title }: { title: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> + <TextShimmerLoader text="Generating podcast" size="sm" /> + </div> + </div> + ); +} + +function PodcastErrorState({ title, error }: { title: string; error: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-destructive">Podcast Generation Failed</p> + </div> + <div className="mx-5 h-px bg-border/50" /> + <div className="px-5 py-4"> + <p className="text-sm font-medium text-foreground line-clamp-2">{title}</p> + <p className="text-sm text-muted-foreground mt-1">{error}</p> + </div> + </div> + ); +} + +function AudioLoadingState({ title }: { title: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> + <TextShimmerLoader text="Loading audio" size="sm" /> + </div> + </div> + ); +} + +function PodcastPlayer({ + podcastId, + title, + durationMs, +}: { + podcastId: number; + title: string; + durationMs?: number; +}) { + const params = useParams(); + const pathname = usePathname(); + const isPublicRoute = pathname?.startsWith("/public/"); + const shareToken = isPublicRoute && typeof params?.token === "string" ? params.token : null; + + const [audioSrc, setAudioSrc] = useState<string | null>(null); + const [transcript, setTranscript] = useState<PodcastTranscriptEntry[] | null>(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState<string | null>(null); + const objectUrlRef = useRef<string | null>(null); + + // Cleanup object URL on unmount + useEffect(() => { + return () => { + if (objectUrlRef.current) { + URL.revokeObjectURL(objectUrlRef.current); + } + }; + }, []); + + // Fetch audio and podcast details (including transcript) + const loadPodcast = useCallback(async () => { + setIsLoading(true); + setError(null); + + try { + // Revoke previous object URL if exists + if (objectUrlRef.current) { + URL.revokeObjectURL(objectUrlRef.current); + objectUrlRef.current = null; + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 60000); // 60s timeout + + try { + let audioBlob: Blob; + let rawPodcastDetails: unknown = null; + + if (shareToken) { + // Public view - use public endpoints (baseApiService handles no-auth for /api/v1/public/) + const [blob, details] = await Promise.all([ + baseApiService.getBlob(`/api/v1/public/${shareToken}/podcasts/${podcastId}/stream`), + baseApiService.get(`/api/v1/public/${shareToken}/podcasts/${podcastId}`), + ]); + audioBlob = blob; + rawPodcastDetails = details; + } else { + // Authenticated view - fetch audio and details in parallel + const [audioResponse, details] = await Promise.all([ + authenticatedFetch(`${BACKEND_URL}/api/v1/podcasts/${podcastId}/audio`, { + method: "GET", + signal: controller.signal, + }), + baseApiService.get<unknown>(`/api/v1/podcasts/${podcastId}`), + ]); + + if (!audioResponse.ok) { + throw new Error(`Failed to load audio: ${audioResponse.status}`); + } + + audioBlob = await audioResponse.blob(); + rawPodcastDetails = details; + } + + // Create object URL from blob + const objectUrl = URL.createObjectURL(audioBlob); + objectUrlRef.current = objectUrl; + setAudioSrc(objectUrl); + + // Parse and validate podcast details, then set transcript + if (rawPodcastDetails) { + const podcastDetails = parsePodcastDetails(rawPodcastDetails); + if (podcastDetails.podcast_transcript) { + setTranscript(podcastDetails.podcast_transcript); + } + } + } finally { + clearTimeout(timeoutId); + } + } catch (err) { + console.error("Error loading podcast:", err); + if (err instanceof DOMException && err.name === "AbortError") { + setError("Request timed out. Please try again."); + } else { + setError(err instanceof Error ? err.message : "Failed to load podcast"); + } + } finally { + setIsLoading(false); + } + }, [podcastId, shareToken]); + + // Load podcast when component mounts + useEffect(() => { + loadPodcast(); + }, [loadPodcast]); + + if (isLoading) { + return <AudioLoadingState title={title} />; + } + + if (error || !audioSrc) { + return <PodcastErrorState title={title} error={error || "Failed to load audio"} />; + } + + const hasTranscript = transcript && transcript.length > 0; + + return ( + <div className="my-4"> + <Audio + id={`podcast-${podcastId}`} + src={audioSrc} + title={title} + durationMs={durationMs} + className={hasTranscript ? "rounded-b-none border-b-0" : undefined} + /> + {hasTranscript && ( + <div className="max-w-lg overflow-hidden rounded-b-2xl border border-t-0 bg-muted/30 select-none"> + <div className="mx-5 h-px bg-border/50" /> + <Accordion type="single" collapsible className="px-5"> + <AccordionItem value="transcript" className="border-b-0"> + <AccordionTrigger className="py-3 text-xs sm:text-sm font-medium text-muted-foreground hover:text-accent-foreground hover:no-underline"> + View transcript + </AccordionTrigger> + <AccordionContent className="pb-0"> + <div className="space-y-2 max-h-64 sm:max-h-96 overflow-y-auto select-text"> + {transcript.map((entry, idx) => ( + <div key={`${idx}-${entry.speaker_id}`} className="text-xs sm:text-sm"> + <span className="font-medium text-primary"> + Speaker {entry.speaker_id + 1}: + </span>{" "} + <span className="text-muted-foreground">{entry.dialog}</span> + </div> + ))} + </div> + </AccordionContent> + </AccordionItem> + </Accordion> + </div> + )} + </div> + ); +} + +/** + * Polling component that checks podcast status and shows player when ready + */ +function PodcastStatusPoller({ podcastId, title }: { podcastId: number; title: string }) { + const [podcastStatus, setPodcastStatus] = useState<PodcastStatusResponse | null>(null); + const pollingRef = useRef<NodeJS.Timeout | null>(null); + + // Set active podcast state when this component mounts + useEffect(() => { + setActivePodcastTaskId(String(podcastId)); + + // Clear when component unmounts + return () => { + clearActivePodcastTaskId(); + }; + }, [podcastId]); + + // Poll for podcast status + useEffect(() => { + const pollStatus = async () => { + try { + const rawResponse = await baseApiService.get<unknown>(`/api/v1/podcasts/${podcastId}`); + const response = parsePodcastStatusResponse(rawResponse); + if (response) { + setPodcastStatus(response); + + // Stop polling if podcast is ready or failed + if (response.status === "ready" || response.status === "failed") { + if (pollingRef.current) { + clearInterval(pollingRef.current); + pollingRef.current = null; + } + clearActivePodcastTaskId(); + } + } + } catch (err) { + console.error("Error polling podcast status:", err); + // Don't stop polling on network errors, continue polling + } + }; + + // Initial poll + pollStatus(); + + // Poll every 5 seconds + pollingRef.current = setInterval(pollStatus, 5000); + + return () => { + if (pollingRef.current) { + clearInterval(pollingRef.current); + } + }; + }, [podcastId]); + + // Show loading state while pending or generating + if ( + !podcastStatus || + podcastStatus.status === "pending" || + podcastStatus.status === "generating" + ) { + return <PodcastGeneratingState title={title} />; + } + + // Show error state + if (podcastStatus.status === "failed") { + return <PodcastErrorState title={title} error={podcastStatus.error || "Generation failed"} />; + } + + // Show player when ready + if (podcastStatus.status === "ready") { + return <PodcastPlayer podcastId={podcastStatus.id} title={podcastStatus.title || title} />; + } + + // Fallback + return <PodcastErrorState title={title} error="Unexpected state" />; +} + +/** + * Generate Podcast Tool UI Component + * + * This component is registered with assistant-ui to render custom UI + * when the generate_podcast tool is called by the agent. + * + * It polls for task completion and auto-updates when the podcast is ready. + */ +export const GeneratePodcastToolUI = ({ + args, + result, + status, +}: ToolCallMessagePartProps<GeneratePodcastArgs, GeneratePodcastResult>) => { + const title = args.podcast_title || "SurfSense Podcast"; + + // Loading state - tool is still running (agent processing) + if (status.type === "running" || status.type === "requires-action") { + return <PodcastGeneratingState title={title} />; + } + + // Incomplete/cancelled state + if (status.type === "incomplete") { + if (status.reason === "cancelled") { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-muted-foreground">Podcast Cancelled</p> + <p className="text-xs text-muted-foreground mt-0.5">Podcast generation was cancelled</p> + </div> + </div> + ); + } + if (status.reason === "error") { + return ( + <PodcastErrorState + title={title} + error={typeof status.error === "string" ? status.error : "An error occurred"} + /> + ); + } + } + + // No result yet + if (!result) { + return <PodcastGeneratingState title={title} />; + } + + // Failed result (new: "failed", legacy: "error") + if (result.status === "failed" || result.status === "error") { + return <PodcastErrorState title={title} error={result.error || "Generation failed"} />; + } + + // Pending/generating rows have a stable podcast_id, so the card can poll + // independently while the chat stream finishes. + if ( + (result.status === "pending" || + result.status === "generating" || + result.status === "processing") && + result.podcast_id + ) { + return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />; + } + + // Legacy duplicate/no-ID result - show a simple warning, don't create + // another poller. The first tool call will display the podcast when ready. + if (result.status === "generating" || result.status === "already_generating") { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground">Podcast already in progress</p> + <p className="text-xs text-muted-foreground mt-0.5"> + Please wait for the current podcast to complete. + </p> + </div> + </div> + ); + } + + // Ready with podcast_id (new: "ready", legacy: "success") + if ((result.status === "ready" || result.status === "success") && result.podcast_id) { + return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />; + } + + // Legacy: old chats with Celery task_id (status: "processing" or "success" without podcast_id) + // These can't be recovered since the old task polling endpoint no longer exists + if (result.task_id && !result.podcast_id) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-muted-foreground">Podcast Unavailable</p> + <p className="text-xs text-muted-foreground mt-0.5"> + This podcast was generated with an older version. Please generate a new one. + </p> + </div> + </div> + ); + } + + // Fallback - missing required data + return <PodcastErrorState title={title} error="Missing podcast ID" />; +}; diff --git a/surfsense_web/components/tool-ui/generate-resume.tsx b/surfsense_web/components/tool-ui/generate-resume.tsx index 9147d4199..5533674bf 100644 --- a/surfsense_web/components/tool-ui/generate-resume.tsx +++ b/surfsense_web/components/tool-ui/generate-resume.tsx @@ -13,7 +13,7 @@ import { Button } from "@/components/ui/button"; import { useMediaQuery } from "@/hooks/use-media-query"; import { baseApiService } from "@/lib/apis/base-api.service"; import { getAuthHeaders } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; pdfjsLib.GlobalWorkerOptions.workerSrc = new URL( "pdfjs-dist/build/pdf.worker.min.mjs", @@ -223,7 +223,7 @@ function ResumeCard({ const previewPath = shareToken ? `/api/v1/public/${shareToken}/reports/${reportId}/preview` : `/api/v1/reports/${reportId}/preview`; - setPdfUrl(buildBackendUrl(previewPath)); + setPdfUrl(`${BACKEND_URL}${previewPath}`); if (autoOpen && isDesktop && !autoOpenedRef.current) { autoOpenedRef.current = true; diff --git a/surfsense_web/components/tool-ui/index.ts b/surfsense_web/components/tool-ui/index.ts index a6576f065..ee5072dad 100644 --- a/surfsense_web/components/tool-ui/index.ts +++ b/surfsense_web/components/tool-ui/index.ts @@ -16,6 +16,7 @@ export { GenerateImageResultSchema, GenerateImageToolUI, } from "./generate-image"; +export { GeneratePodcastToolUI } from "./generate-podcast"; export { GenerateReportToolUI } from "./generate-report"; export { CreateGoogleDriveFileToolUI, DeleteGoogleDriveFileToolUI } from "./google-drive"; export { @@ -43,7 +44,6 @@ export { type SerializablePlan, type TodoStatus, } from "./plan"; -export { GeneratePodcastToolUI } from "./podcast"; export { type ExecuteArgs, ExecuteArgsSchema, diff --git a/surfsense_web/components/tool-ui/podcast/brief-review.tsx b/surfsense_web/components/tool-ui/podcast/brief-review.tsx deleted file mode 100644 index d982d6f85..000000000 --- a/surfsense_web/components/tool-ui/podcast/brief-review.tsx +++ /dev/null @@ -1,572 +0,0 @@ -"use client"; - -import { Check, ChevronDown, Loader2, Plus, Trash2 } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { toast } from "sonner"; -import { Button } from "@/components/ui/button"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Textarea } from "@/components/ui/textarea"; -import { - type LanguageOptions, - MAX_DURATION_SECONDS, - MAX_SPEAKERS, - MIN_DURATION_SECONDS, - type PodcastSpec, - type PodcastStyle, - podcastStyle, - type SpeakerRole, - speakerRole, - type VoiceOption, -} from "@/contracts/types/podcast.types"; -import type { LivePodcast } from "@/hooks/use-podcast-live"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { AppError } from "@/lib/error"; -import { VoicePreviewButton } from "./voice-preview-button"; - -// A "*" voice speaks whatever language the text is in (mirrors the backend -// catalog's ANY_LANGUAGE sentinel). -const ANY_LANGUAGE = "*"; - -function speaks(voice: VoiceOption, language: string): boolean { - if (voice.language === ANY_LANGUAGE) return true; - return primary(voice.language) === primary(language); -} - -function primary(language: string): string { - return language.split("-", 1)[0].trim().toLowerCase(); -} - -interface BriefReviewProps { - podcast: LivePodcast; - spec: PodcastSpec; -} - -/** - * The brief gate, rendered inline in the chat card: a pre-filled - * near-confirmation. One-click approve is the easy path; every field stays - * overridable and saves through the version-guarded PATCH so concurrent edits - * surface instead of clobbering. Approval needs no local follow-up — the - * pushed status flips the card to its drafting state. - */ -export function BriefReview({ podcast, spec }: BriefReviewProps) { - const [draft, setDraft] = useState<PodcastSpec>(spec); - const [durationUnit, setDurationUnit] = useState<DurationUnit>(() => - defaultDurationUnit(spec.duration.max_seconds) - ); - const [voices, setVoices] = useState<VoiceOption[] | null>(null); - const [offering, setOffering] = useState<LanguageOptions | null>(null); - const [isSubmitting, setIsSubmitting] = useState(false); - - // A pushed spec change (saved edit or concurrent editor) resets the form to - // the authoritative version. - // biome-ignore lint/correctness/useExhaustiveDependencies: reset only when the server version moves - useEffect(() => { - setDraft(spec); - setDurationUnit(defaultDurationUnit(spec.duration.max_seconds)); - }, [podcast.specVersion]); - - useEffect(() => { - let cancelled = false; - podcastsApiService - .listVoices() - .then((catalog) => { - if (!cancelled) setVoices(catalog); - }) - .catch(() => { - if (!cancelled) setVoices([]); - }); - podcastsApiService - .listLanguages() - .then((options) => { - if (!cancelled) setOffering(options); - }) - .catch(() => { - if (!cancelled) setOffering({ languages: [], allows_custom: false }); - }); - return () => { - cancelled = true; - }; - }, []); - - // The backend owns the offering; the draft's language stays listed even - // when it falls outside it (e.g. a custom tag entered earlier). - const languages = useMemo(() => { - const tags = new Set(offering?.languages ?? []); - tags.add(draft.language); - return [...tags].sort(); - }, [offering, draft.language]); - - const voicesForLanguage = useMemo( - () => (voices ?? []).filter((voice) => speaks(voice, draft.language)), - [voices, draft.language] - ); - - const isDirty = useMemo(() => JSON.stringify(draft) !== JSON.stringify(spec), [draft, spec]); - - const setLanguage = (language: string) => { - setDraft((current) => { - const candidates = (voices ?? []).filter((voice) => speaks(voice, language)); - // Voices that can't render the new language are remapped so the saved - // spec never pairs a language with an incompatible voice. - const speakers = current.speakers.map((speaker, index) => { - const stillValid = candidates.some((voice) => voice.voice_id === speaker.voice_id); - const fallback = candidates[index % Math.max(candidates.length, 1)]; - return stillValid || !fallback ? speaker : { ...speaker, voice_id: fallback.voice_id }; - }); - return { ...current, language, speakers }; - }); - }; - - const setStyle = (style: PodcastStyle) => { - setDraft((current) => ({ - ...current, - style, - // A monologue has exactly one speaker, so extra speakers are dropped - // rather than letting approval fail validation. - speakers: style === "monologue" ? current.speakers.slice(0, 1) : current.speakers, - })); - }; - - const updateSpeaker = (slot: number, change: Partial<PodcastSpec["speakers"][number]>) => { - setDraft((current) => ({ - ...current, - speakers: current.speakers.map((speaker) => - speaker.slot === slot ? { ...speaker, ...change } : speaker - ), - })); - }; - - const addSpeaker = () => { - setDraft((current) => { - if (current.speakers.length >= MAX_SPEAKERS) return current; - const slot = Math.max(...current.speakers.map((s) => s.slot)) + 1; - const voice = - voicesForLanguage[current.speakers.length % Math.max(voicesForLanguage.length, 1)]; - return { - ...current, - speakers: [ - ...current.speakers, - { - slot, - name: `Speaker ${current.speakers.length + 1}`, - role: "guest" as SpeakerRole, - voice_id: voice?.voice_id ?? current.speakers[0].voice_id, - }, - ], - }; - }); - }; - - const removeSpeaker = (slot: number) => { - setDraft((current) => { - if (current.speakers.length <= 1) return current; - return { - ...current, - speakers: current.speakers.filter((speaker) => speaker.slot !== slot), - }; - }); - }; - - const saveIfDirty = async (): Promise<boolean> => { - if (!isDirty) return true; - try { - await podcastsApiService.updateSpec(podcast.id, draft, podcast.specVersion); - return true; - } catch (error) { - if (error instanceof AppError && error.status === 409) { - toast.warning("The brief changed elsewhere — reloaded the latest version."); - setDraft(spec); - } else { - toast.error(error instanceof Error ? error.message : "Failed to save the brief"); - } - return false; - } - }; - - const handleApprove = async () => { - setIsSubmitting(true); - try { - if (!(await saveIfDirty())) return; - await podcastsApiService.approveBrief(podcast.id); - } catch (error) { - toast.error(error instanceof Error ? error.message : "Failed to approve the brief"); - } finally { - setIsSubmitting(false); - } - }; - - return ( - <div className="flex flex-col gap-6"> - <div className="grid grid-cols-2 gap-4"> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-language">Language</Label> - {offering?.allows_custom ? ( - <LanguageCombobox value={draft.language} languages={languages} onSelect={setLanguage} /> - ) : ( - <Select value={draft.language} onValueChange={setLanguage}> - <SelectTrigger id="podcast-language"> - <SelectValue placeholder="Language" /> - </SelectTrigger> - <SelectContent> - {languages.map((tag) => ( - <SelectItem key={tag} value={tag}> - {languageLabel(tag)} - </SelectItem> - ))} - </SelectContent> - </Select> - )} - </div> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-style">Style</Label> - <Select value={draft.style} onValueChange={(value) => setStyle(value as PodcastStyle)}> - <SelectTrigger id="podcast-style"> - <SelectValue placeholder="Style" /> - </SelectTrigger> - <SelectContent> - {podcastStyle.options.map((style) => ( - <SelectItem key={style} value={style}> - {capitalize(style)} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - </div> - - <div className="flex flex-col gap-3"> - <div className="flex items-center justify-between"> - <Label>Speakers</Label> - <Button - type="button" - variant="ghost" - size="sm" - onClick={addSpeaker} - disabled={draft.style === "monologue" || draft.speakers.length >= MAX_SPEAKERS} - > - <Plus className="size-4" /> Add speaker - </Button> - </div> - {draft.speakers.map((speaker) => ( - <div key={speaker.slot} className="flex items-end gap-2 rounded-lg border p-3"> - <div className="flex flex-1 flex-col gap-1.5"> - <Label htmlFor={`speaker-name-${speaker.slot}`} className="text-xs"> - Name - </Label> - <Input - id={`speaker-name-${speaker.slot}`} - value={speaker.name} - maxLength={120} - onChange={(e) => updateSpeaker(speaker.slot, { name: e.target.value })} - /> - </div> - <div className="flex w-28 flex-col gap-1.5"> - <Label className="text-xs">Role</Label> - <Select - value={speaker.role} - onValueChange={(value) => - updateSpeaker(speaker.slot, { role: value as SpeakerRole }) - } - > - <SelectTrigger> - <SelectValue /> - </SelectTrigger> - <SelectContent> - {speakerRole.options.map((role) => ( - <SelectItem key={role} value={role}> - {capitalize(role)} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - <div className="flex w-52 flex-col gap-1.5"> - <Label className="text-xs">Voice</Label> - <div className="flex items-center gap-1"> - <Select - value={speaker.voice_id} - onValueChange={(value) => updateSpeaker(speaker.slot, { voice_id: value })} - > - <SelectTrigger> - <SelectValue placeholder={voices === null ? "Loading…" : "Voice"} /> - </SelectTrigger> - <SelectContent> - {voiceItems(voicesForLanguage, speaker.voice_id).map((voice) => ( - <SelectItem key={voice.voice_id} value={voice.voice_id}> - {voice.display_name} ({voice.gender}) - </SelectItem> - ))} - </SelectContent> - </Select> - <VoicePreviewButton voiceId={speaker.voice_id} /> - </div> - </div> - <Button - type="button" - variant="ghost" - size="icon" - aria-label={`Remove ${speaker.name}`} - onClick={() => removeSpeaker(speaker.slot)} - disabled={draft.speakers.length <= 1} - > - <Trash2 className="size-4" /> - </Button> - </div> - ))} - </div> - - <div className="flex flex-col gap-2"> - <div className="flex items-center justify-between gap-3"> - <Label>Target length</Label> - <Select - value={durationUnit} - onValueChange={(value) => setDurationUnit(value as DurationUnit)} - > - <SelectTrigger className="w-[7.5rem]" aria-label="Length unit"> - <SelectValue /> - </SelectTrigger> - <SelectContent> - <SelectItem value="seconds">Seconds</SelectItem> - <SelectItem value="minutes">Minutes</SelectItem> - <SelectItem value="hours">Hours</SelectItem> - </SelectContent> - </Select> - </div> - <div className="grid grid-cols-2 gap-4"> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-min-length">Min</Label> - <Input - id="podcast-min-length" - type="number" - min={durationUnitBounds(durationUnit).min} - max={durationUnitBounds(durationUnit).max} - step={durationInputStep(durationUnit)} - value={formatDurationForUnit(draft.duration.min_seconds, durationUnit)} - onChange={(e) => { - const seconds = clampDurationSeconds( - fromUnitValue(Number(e.target.value), durationUnit) - ); - setDraft((current) => ({ - ...current, - duration: { ...current.duration, min_seconds: seconds }, - })); - }} - /> - </div> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-max-length">Max</Label> - <Input - id="podcast-max-length" - type="number" - min={secondsToUnitValue(draft.duration.min_seconds, durationUnit)} - max={durationUnitBounds(durationUnit).max} - step={durationInputStep(durationUnit)} - value={formatDurationForUnit(draft.duration.max_seconds, durationUnit)} - onChange={(e) => { - const parsed = Number(e.target.value); - const fallback = secondsToUnitValue(draft.duration.min_seconds, durationUnit); - const seconds = clampDurationSeconds( - fromUnitValue(Number.isFinite(parsed) ? parsed : fallback, durationUnit) - ); - setDraft((current) => ({ - ...current, - duration: { ...current.duration, max_seconds: seconds }, - })); - }} - /> - </div> - </div> - </div> - - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-focus">Focus (optional)</Label> - <Textarea - id="podcast-focus" - placeholder="What should the episode emphasise?" - maxLength={2000} - value={draft.focus ?? ""} - onChange={(e) => setDraft((current) => ({ ...current, focus: e.target.value || null }))} - /> - </div> - - <div className="flex justify-end gap-2"> - {isDirty ? ( - <Button - type="button" - variant="ghost" - onClick={() => setDraft(spec)} - disabled={isSubmitting} - > - Discard - </Button> - ) : null} - <Button - type="button" - onClick={handleApprove} - disabled={isSubmitting || draft.duration.max_seconds < draft.duration.min_seconds} - > - {isSubmitting ? <Loader2 className="size-4 animate-spin" /> : null} - {isDirty ? "Approve changes & draft transcript" : "Approve & draft transcript"} - </Button> - </div> - </div> - ); -} - -/** A searchable language picker for providers whose voices speak anything: - * the offered list comes from the backend, and any BCP-47 tag may be typed - * when none of them fits. */ -function LanguageCombobox({ - value, - languages, - onSelect, -}: { - value: string; - languages: string[]; - onSelect: (language: string) => void; -}) { - const [open, setOpen] = useState(false); - const [query, setQuery] = useState(""); - - const pick = (tag: string) => { - onSelect(tag); - setOpen(false); - setQuery(""); - }; - - const customTag = query.trim(); - const isNewTag = - customTag.length > 0 && !languages.some((tag) => tag.toLowerCase() === customTag.toLowerCase()); - - return ( - <Popover open={open} onOpenChange={setOpen}> - <PopoverTrigger asChild> - <button - type="button" - role="combobox" - aria-expanded={open} - id="podcast-language" - className="border-popover-border flex h-9 w-full items-center justify-between gap-2 rounded-md border bg-transparent px-3 py-2 text-sm whitespace-nowrap shadow-xs outline-none transition-[color,box-shadow] disabled:cursor-not-allowed disabled:opacity-50" - > - <span className="line-clamp-1 text-left">{languageLabel(value)}</span> - <ChevronDown className="size-4 shrink-0 opacity-50" /> - </button> - </PopoverTrigger> - <PopoverContent className="w-[var(--radix-popover-trigger-width)] p-0" align="start"> - <Command> - <CommandInput - placeholder="Search or type a language tag…" - value={query} - onValueChange={setQuery} - /> - <CommandList> - <CommandEmpty>No matching language.</CommandEmpty> - <CommandGroup> - {languages.map((tag) => ( - <CommandItem - key={tag} - value={tag} - keywords={[languageLabel(tag)]} - onSelect={() => pick(tag)} - > - <Check className={tag === value ? "size-4" : "size-4 opacity-0"} /> - {languageLabel(tag)} - </CommandItem> - ))} - {isNewTag ? ( - <CommandItem value={customTag} onSelect={() => pick(customTag)}> - <Plus className="size-4" /> - Use “{customTag}” - </CommandItem> - ) : null} - </CommandGroup> - </CommandList> - </Command> - </PopoverContent> - </Popover> - ); -} - -/** The current selection stays listed even when it no longer matches the - * language filter, so the Select never renders an orphaned value. */ -type DurationUnit = "seconds" | "minutes" | "hours"; - -function defaultDurationUnit(maxSeconds: number): DurationUnit { - if (maxSeconds >= 3600) return "hours"; - if (maxSeconds >= 60) return "minutes"; - return "seconds"; -} - -function secondsToUnitValue(seconds: number, unit: DurationUnit): number { - if (unit === "minutes") return seconds / 60; - if (unit === "hours") return seconds / 3600; - return seconds; -} - -function fromUnitValue(value: number, unit: DurationUnit): number { - if (!Number.isFinite(value)) return MIN_DURATION_SECONDS; - if (unit === "minutes") return value * 60; - if (unit === "hours") return value * 3600; - return value; -} - -function formatDurationForUnit(seconds: number, unit: DurationUnit): number { - const raw = secondsToUnitValue(seconds, unit); - if (unit === "seconds") return Math.round(raw); - return Math.round(raw * 100) / 100; -} - -function durationInputStep(unit: DurationUnit): number { - if (unit === "hours") return 0.1; - return 1; -} - -function durationUnitBounds(unit: DurationUnit): { min: number; max: number } { - return { - min: formatDurationForUnit(MIN_DURATION_SECONDS, unit), - max: formatDurationForUnit(MAX_DURATION_SECONDS, unit), - }; -} - -function clampDurationSeconds(value: number): number { - if (!Number.isFinite(value)) return MIN_DURATION_SECONDS; - return Math.min(MAX_DURATION_SECONDS, Math.max(MIN_DURATION_SECONDS, Math.round(value))); -} - -function voiceItems(candidates: VoiceOption[], selectedId: string): VoiceOption[] { - if (candidates.some((voice) => voice.voice_id === selectedId)) return candidates; - return [ - { voice_id: selectedId, display_name: selectedId, language: "", gender: "unknown" }, - ...candidates, - ]; -} - -function languageLabel(tag: string): string { - try { - const label = new Intl.DisplayNames(["en"], { type: "language" }).of(tag); - return label && label !== tag ? `${label} (${tag})` : tag; - } catch { - return tag; - } -} - -function capitalize(value: string): string { - return value.charAt(0).toUpperCase() + value.slice(1); -} diff --git a/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx b/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx deleted file mode 100644 index f881be9dd..000000000 --- a/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx +++ /dev/null @@ -1,371 +0,0 @@ -"use client"; - -import type { ToolCallMessagePartProps } from "@assistant-ui/react"; -import { Loader2, RotateCcw, Undo2, X } from "lucide-react"; -import { usePathname } from "next/navigation"; -import { type ReactNode, useEffect, useState } from "react"; -import { toast } from "sonner"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Button, buttonVariants } from "@/components/ui/button"; -import { type LivePodcast, usePodcastLive } from "@/hooks/use-podcast-live"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { BriefReview } from "./brief-review"; -import { PodcastErrorState, PodcastPlayer } from "./player"; -import type { GeneratePodcastArgs, GeneratePodcastResult } from "./schema"; - -function WorkingState({ - title, - label, - action, -}: { - title: string; - label: string; - action?: ReactNode; -}) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="flex items-start justify-between gap-3 px-5 pt-5 pb-4"> - <div className="min-w-0"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <TextShimmerLoader text={label} size="sm" /> - </div> - {action} - </div> - </div> - ); -} - -function NoticeState({ title, message }: { title: string; message: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-muted-foreground">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5">{message}</p> - </div> - </div> - ); -} - -/** - * Regenerating reopens the brief and ultimately replaces the current audio, - * so a stray click is guarded by an inline confirm step. - */ -function RegenerateButton({ podcast }: { podcast: LivePodcast }) { - const [confirming, setConfirming] = useState(false); - const [isSubmitting, setIsSubmitting] = useState(false); - - const regenerate = async () => { - setIsSubmitting(true); - try { - await podcastsApiService.regenerate(podcast.id); - } catch (error) { - toast.error(error instanceof Error ? error.message : "Failed to regenerate the podcast"); - } finally { - setIsSubmitting(false); - setConfirming(false); - } - }; - - if (!confirming) { - return ( - <Button - type="button" - variant="ghost" - size="sm" - className="text-muted-foreground" - onClick={() => setConfirming(true)} - > - <RotateCcw className="size-3.5" /> Regenerate - </Button> - ); - } - - return ( - <div className="flex items-center gap-2"> - <span className="text-xs text-muted-foreground"> - Reopen the brief and replace this episode? - </span> - <Button - type="button" - variant="ghost" - size="sm" - onClick={() => setConfirming(false)} - disabled={isSubmitting} - > - Keep it - </Button> - <Button - type="button" - variant="destructive" - size="sm" - onClick={regenerate} - disabled={isSubmitting} - > - {isSubmitting ? <Loader2 className="size-3.5 animate-spin" /> : null} - Regenerate - </Button> - </div> - ); -} - -/** - * The way out of an in-flight generation depends on what already exists: - * a regeneration is reverted (the stored episode survives, so no confirm), - * while a first-time generation is cancelled (destructive, so confirmed via a - * dialog — the card header is too cramped to host a confirmation row). - */ -function BackOutButton({ podcastId, hasEpisode }: { podcastId: number; hasEpisode: boolean }) { - const [isSubmitting, setIsSubmitting] = useState(false); - - const run = async (call: (id: number) => Promise<unknown>, failure: string) => { - setIsSubmitting(true); - try { - await call(podcastId); - } catch (error) { - toast.error(error instanceof Error ? error.message : failure); - } finally { - setIsSubmitting(false); - } - }; - - if (hasEpisode) { - return ( - <Button - type="button" - variant="ghost" - size="sm" - className="shrink-0 text-muted-foreground" - disabled={isSubmitting} - onClick={() => - run(podcastsApiService.revertRegeneration, "Failed to restore the current episode") - } - > - {isSubmitting ? ( - <Loader2 className="size-3.5 animate-spin" /> - ) : ( - <Undo2 className="size-3.5" /> - )} - Keep current episode - </Button> - ); - } - - return ( - <AlertDialog> - <AlertDialogTrigger asChild> - <Button - type="button" - variant="ghost" - size="sm" - className="shrink-0 text-muted-foreground" - disabled={isSubmitting} - > - <X className="size-3.5" /> Cancel - </Button> - </AlertDialogTrigger> - <AlertDialogContent> - <AlertDialogHeader> - <AlertDialogTitle>Cancel this podcast?</AlertDialogTitle> - <AlertDialogDescription> - Generation stops and the podcast is discarded. This cannot be undone. - </AlertDialogDescription> - </AlertDialogHeader> - <AlertDialogFooter> - <AlertDialogCancel>Keep going</AlertDialogCancel> - <AlertDialogAction - className={buttonVariants({ variant: "destructive" })} - onClick={() => run(podcastsApiService.cancel, "Failed to cancel the podcast")} - > - Cancel podcast - </AlertDialogAction> - </AlertDialogFooter> - </AlertDialogContent> - </AlertDialog> - ); -} - -const BACK_OUT_STATUSES = new Set(["awaiting_brief", "drafting", "rendering"]); - -/** Status-driven card for an authenticated viewer, fed by Zero push. */ -function LivePodcastCard({ - podcastId, - fallbackTitle, -}: { - podcastId: number; - fallbackTitle: string; -}) { - const { podcast, isLoading } = usePodcastLive(podcastId); - - // Whether a finished episode exists decides revert-vs-cancel, and Zero - // doesn't publish audio fields — so the in-flight states check over REST, - // re-checking on each status change (a fresh podcast gains its episode, - // a regeneration starts with one). - const status = podcast?.status; - const [hasEpisode, setHasEpisode] = useState(false); - useEffect(() => { - if (!status || !BACK_OUT_STATUSES.has(status)) return; - let stale = false; - podcastsApiService - .getDetail(podcastId) - .then((detail) => { - if (!stale) setHasEpisode(detail.has_audio); - }) - .catch(() => {}); - return () => { - stale = true; - }; - }, [podcastId, status]); - - if (!podcast) { - if (isLoading) { - return <WorkingState title={fallbackTitle} label="Loading podcast" />; - } - return ( - <NoticeState - title="Podcast Unavailable" - message="This podcast no longer exists or you don't have access to it." - /> - ); - } - - const title = podcast.title || fallbackTitle; - - const backOut = <BackOutButton podcastId={podcast.id} hasEpisode={hasEpisode} />; - - switch (podcast.status) { - case "pending": - return <WorkingState title={title} label="Preparing brief" />; - case "drafting": - return <WorkingState title={title} label="Drafting transcript" action={backOut} />; - case "rendering": - return <WorkingState title={title} label="Rendering audio" action={backOut} />; - case "awaiting_brief": - // The gate lives right in the chat: the form is the card, so there - // is nothing to open and nothing to dismiss. - if (!podcast.spec) { - return <WorkingState title={title} label="Preparing brief" />; - } - return ( - <div className="my-4 max-w-xl overflow-hidden rounded-2xl border bg-muted/30"> - <div className="flex items-start justify-between gap-3 px-5 pt-5 pb-3 select-none"> - <div className="min-w-0"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5"> - Confirm the language, voices, and length — the episode generates automatically after - you approve. - </p> - </div> - {backOut} - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-4"> - <BriefReview podcast={podcast} spec={podcast.spec} /> - </div> - </div> - ); - case "awaiting_review": - // Legacy rows parked at the removed transcript gate; the only way - // forward is regenerating through the brief gate. - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5"> - This podcast was drafted before audio rendering became automatic. - </p> - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="flex justify-end px-5 py-3"> - <RegenerateButton podcast={podcast} /> - </div> - </div> - ); - case "ready": - return ( - <div> - <PodcastPlayer - podcastId={podcast.id} - title={title} - durationMs={podcast.durationSeconds ? podcast.durationSeconds * 1000 : undefined} - /> - <div className="-mt-2 mb-4 flex max-w-lg justify-end"> - <RegenerateButton podcast={podcast} /> - </div> - </div> - ); - case "failed": - return <PodcastErrorState title={title} error={podcast.error || "Generation failed"} />; - case "cancelled": - return <NoticeState title="Podcast Cancelled" message="This podcast was cancelled." />; - } -} - -/** - * Tool UI for `generate_podcast`. The tool only prepares the podcast (it - * returns with the brief awaiting review), so this card follows the lifecycle - * by Zero push, rendering the brief form inline at the gate. Public shared - * chats have no Zero session; their snapshots only ever contain finished - * episodes, so the player renders directly against the share-token endpoints. - */ -export const GeneratePodcastToolUI = ({ - args, - result, - status, -}: ToolCallMessagePartProps<GeneratePodcastArgs, GeneratePodcastResult>) => { - const pathname = usePathname(); - const isPublicRoute = !!pathname?.startsWith("/public/"); - const title = args.podcast_title || "SurfSense Podcast"; - - if (status.type === "running" || status.type === "requires-action") { - return <WorkingState title={title} label="Preparing podcast" />; - } - - if (status.type === "incomplete") { - if (status.reason === "cancelled") { - return <NoticeState title="Podcast Cancelled" message="Podcast preparation was cancelled." />; - } - if (status.reason === "error") { - return ( - <PodcastErrorState - title={title} - error={typeof status.error === "string" ? status.error : "An error occurred"} - /> - ); - } - } - - if (!result) { - return <WorkingState title={title} label="Preparing podcast" />; - } - - if (result.podcast_id) { - if (isPublicRoute) { - return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />; - } - return <LivePodcastCard podcastId={result.podcast_id} fallbackTitle={result.title || title} />; - } - - if (result.status === "failed" || result.status === "error") { - return <PodcastErrorState title={title} error={result.error || "Generation failed"} />; - } - - // Legacy saved chats: results identified only by a Celery task id can't be - // recovered through the lifecycle API. - return ( - <NoticeState - title="Podcast Unavailable" - message="This podcast was generated with an older version. Please generate a new one." - /> - ); -}; diff --git a/surfsense_web/components/tool-ui/podcast/index.ts b/surfsense_web/components/tool-ui/podcast/index.ts deleted file mode 100644 index 1e5e5e06e..000000000 --- a/surfsense_web/components/tool-ui/podcast/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { GeneratePodcastToolUI } from "./generate-podcast"; diff --git a/surfsense_web/components/tool-ui/podcast/player.tsx b/surfsense_web/components/tool-ui/podcast/player.tsx deleted file mode 100644 index ac00b6780..000000000 --- a/surfsense_web/components/tool-ui/podcast/player.tsx +++ /dev/null @@ -1,209 +0,0 @@ -"use client"; - -import { useParams, usePathname } from "next/navigation"; -import { useCallback, useEffect, useRef, useState } from "react"; -import { z } from "zod"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { Audio } from "@/components/tool-ui/audio"; -import { - Accordion, - AccordionContent, - AccordionItem, - AccordionTrigger, -} from "@/components/ui/accordion"; -import { baseApiService } from "@/lib/apis/base-api.service"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; -import { speakerLabel } from "./schema"; - -// Public snapshots predate the transcript.turns shape and keep their own. -const publicPodcastDetailsSchema = z.object({ - podcast_transcript: z.array(z.object({ speaker_id: z.number(), dialog: z.string() })).nullish(), -}); - -interface TranscriptLine { - // Transcripts are immutable once fetched, so a turn's position identifies it. - key: string; - label: string; - text: string; -} - -export function PodcastErrorState({ title, error }: { title: string; error: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-destructive">Podcast Generation Failed</p> - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-4"> - <p className="text-sm font-medium text-foreground line-clamp-2">{title}</p> - <p className="text-sm text-muted-foreground mt-1">{error}</p> - </div> - </div> - ); -} - -function AudioLoadingState({ title }: { title: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <TextShimmerLoader text="Loading audio" size="sm" /> - </div> - </div> - ); -} - -/** - * Streams the rendered episode and shows its transcript. Works in two modes: - * authenticated (lifecycle stream + detail endpoints) and public shared chat - * (share-token snapshot endpoints), detected from the route. - */ -export function PodcastPlayer({ - podcastId, - title, - durationMs, -}: { - podcastId: number; - title: string; - durationMs?: number; -}) { - const params = useParams(); - const pathname = usePathname(); - const isPublicRoute = pathname?.startsWith("/public/"); - const shareToken = isPublicRoute && typeof params?.token === "string" ? params.token : null; - - const [audioSrc, setAudioSrc] = useState<string | null>(null); - const [transcriptLines, setTranscriptLines] = useState<TranscriptLine[] | null>(null); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState<string | null>(null); - const objectUrlRef = useRef<string | null>(null); - - useEffect(() => { - return () => { - if (objectUrlRef.current) { - URL.revokeObjectURL(objectUrlRef.current); - } - }; - }, []); - - const loadPodcast = useCallback(async () => { - setIsLoading(true); - setError(null); - - try { - if (objectUrlRef.current) { - URL.revokeObjectURL(objectUrlRef.current); - objectUrlRef.current = null; - } - - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 60000); - - try { - let audioBlob: Blob; - let lines: TranscriptLine[] = []; - - if (shareToken) { - const [blob, details] = await Promise.all([ - baseApiService.getBlob(`/api/v1/public/${shareToken}/podcasts/${podcastId}/stream`), - baseApiService.get(`/api/v1/public/${shareToken}/podcasts/${podcastId}`), - ]); - audioBlob = blob; - const parsed = publicPodcastDetailsSchema.safeParse(details); - lines = (parsed.success ? (parsed.data.podcast_transcript ?? []) : []).map( - (entry, turn) => ({ - key: `turn-${turn}`, - label: `Speaker ${entry.speaker_id + 1}`, - text: entry.dialog, - }) - ); - } else { - const [audioResponse, detail] = await Promise.all([ - authenticatedFetch(buildBackendUrl(`/api/v1/podcasts/${podcastId}/stream`), { - method: "GET", - signal: controller.signal, - }), - podcastsApiService.getDetail(podcastId), - ]); - - if (!audioResponse.ok) { - throw new Error(`Failed to load audio: ${audioResponse.status}`); - } - - audioBlob = await audioResponse.blob(); - lines = (detail.transcript?.turns ?? []).map((entry, turn) => ({ - key: `turn-${turn}`, - label: speakerLabel(detail.spec, entry.speaker), - text: entry.text, - })); - } - - const objectUrl = URL.createObjectURL(audioBlob); - objectUrlRef.current = objectUrl; - setAudioSrc(objectUrl); - setTranscriptLines(lines); - } finally { - clearTimeout(timeoutId); - } - } catch (err) { - console.error("Error loading podcast:", err); - if (err instanceof DOMException && err.name === "AbortError") { - setError("Request timed out. Please try again."); - } else { - setError(err instanceof Error ? err.message : "Failed to load podcast"); - } - } finally { - setIsLoading(false); - } - }, [podcastId, shareToken]); - - useEffect(() => { - loadPodcast(); - }, [loadPodcast]); - - if (isLoading) { - return <AudioLoadingState title={title} />; - } - - if (error || !audioSrc) { - return <PodcastErrorState title={title} error={error || "Failed to load audio"} />; - } - - const hasTranscript = transcriptLines && transcriptLines.length > 0; - - return ( - <div className="my-4"> - <Audio - id={`podcast-${podcastId}`} - src={audioSrc} - title={title} - durationMs={durationMs} - className={hasTranscript ? "rounded-b-none border-b-0" : undefined} - /> - {hasTranscript ? ( - <div className="max-w-lg overflow-hidden rounded-b-2xl border border-t-0 bg-muted/30 select-none"> - <div className="mx-5 h-px bg-border/50" /> - <Accordion type="single" collapsible className="px-5"> - <AccordionItem value="transcript" className="border-b-0"> - <AccordionTrigger className="py-3 text-xs sm:text-sm font-medium text-muted-foreground hover:text-accent-foreground hover:no-underline"> - View transcript - </AccordionTrigger> - <AccordionContent className="pb-0"> - <div className="space-y-2 max-h-64 sm:max-h-96 overflow-y-auto select-text"> - {transcriptLines.map((line) => ( - <div key={line.key} className="text-xs sm:text-sm"> - <span className="font-medium text-primary">{line.label}:</span>{" "} - <span className="text-muted-foreground">{line.text}</span> - </div> - ))} - </div> - </AccordionContent> - </AccordionItem> - </Accordion> - </div> - ) : null} - </div> - ); -} diff --git a/surfsense_web/components/tool-ui/podcast/schema.ts b/surfsense_web/components/tool-ui/podcast/schema.ts deleted file mode 100644 index 91937eaad..000000000 --- a/surfsense_web/components/tool-ui/podcast/schema.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { z } from "zod"; -import type { PodcastSpec } from "@/contracts/types/podcast.types"; - -/** - * Tool-call contract for `generate_podcast`. - * - * The tool prepares a podcast and returns immediately with the row awaiting - * brief review; the card then follows the lifecycle by push. Legacy status - * values are accepted so old saved chats still render something sensible. - */ - -export const generatePodcastArgsSchema = z.object({ - source_content: z.string(), - podcast_title: z.string().nullish(), - user_prompt: z.string().nullish(), -}); -export type GeneratePodcastArgs = z.infer<typeof generatePodcastArgsSchema>; - -export const generatePodcastResultSchema = z.object({ - status: z.string(), - podcast_id: z.number().nullish(), - task_id: z.string().nullish(), // legacy Celery id from old saved chats - title: z.string().nullish(), - message: z.string().nullish(), - error: z.string().nullish(), -}); -export type GeneratePodcastResult = z.infer<typeof generatePodcastResultSchema>; - -/** Display name for the speaker bound to `slot`, falling back to a number. */ -export function speakerLabel(spec: PodcastSpec | null | undefined, slot: number): string { - const speaker = spec?.speakers.find((s) => s.slot === slot); - return speaker?.name ?? `Speaker ${slot + 1}`; -} diff --git a/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx b/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx deleted file mode 100644 index 989b15e0f..000000000 --- a/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx +++ /dev/null @@ -1,98 +0,0 @@ -"use client"; - -import { Loader2, Play, Square } from "lucide-react"; -import { useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; -import { Button } from "@/components/ui/button"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; - -// Comparing voices means replaying the same samples, so each voice is fetched -// at most once per page lifetime. -const sampleUrls = new Map<string, Promise<string>>(); - -// Overlapping samples are useless for comparison, so only one plays at a time. -let activeAudio: HTMLAudioElement | null = null; -let stopActive: (() => void) | null = null; - -function getSampleUrl(voiceId: string): Promise<string> { - let url = sampleUrls.get(voiceId); - if (!url) { - url = podcastsApiService.previewVoice(voiceId).then((blob) => URL.createObjectURL(blob)); - // A failed fetch must not poison the cache for retries. - url.catch(() => sampleUrls.delete(voiceId)); - sampleUrls.set(voiceId, url); - } - return url; -} - -/** Plays a short sample of `voiceId` so users pick voices by sound. */ -export function VoicePreviewButton({ voiceId }: { voiceId: string }) { - const [state, setState] = useState<"idle" | "loading" | "playing">("idle"); - const mountedRef = useRef(true); - - useEffect(() => { - mountedRef.current = true; - return () => { - mountedRef.current = false; - if (stopActive && activeAudio?.dataset.voiceId === voiceId) { - stopActive(); - } - }; - }, [voiceId]); - - const stop = () => { - if (stopActive) stopActive(); - }; - - const play = async () => { - stop(); - setState("loading"); - try { - const url = await getSampleUrl(voiceId); - if (!mountedRef.current) return; - - const audio = new Audio(url); - audio.dataset.voiceId = voiceId; - activeAudio = audio; - stopActive = () => { - audio.pause(); - activeAudio = null; - stopActive = null; - if (mountedRef.current) setState("idle"); - }; - audio.onended = () => { - if (activeAudio === audio) { - activeAudio = null; - stopActive = null; - } - if (mountedRef.current) setState("idle"); - }; - await audio.play(); - if (mountedRef.current) setState("playing"); - } catch (error) { - if (mountedRef.current) setState("idle"); - toast.error(error instanceof Error ? error.message : "Couldn't play the voice sample"); - } - }; - - const isPlaying = state === "playing"; - - return ( - <Button - type="button" - variant="ghost" - size="icon" - aria-label={isPlaying ? "Stop voice sample" : "Play voice sample"} - disabled={state === "loading"} - onClick={isPlaying ? stop : play} - > - {state === "loading" ? ( - <Loader2 className="size-4 animate-spin" /> - ) : isPlaying ? ( - <Square className="size-4" /> - ) : ( - <Play className="size-4" /> - )} - </Button> - ); -} diff --git a/surfsense_web/components/tool-ui/sandbox-execute.tsx b/surfsense_web/components/tool-ui/sandbox-execute.tsx index a7633d0ec..3d309332e 100644 --- a/surfsense_web/components/tool-ui/sandbox-execute.tsx +++ b/surfsense_web/components/tool-ui/sandbox-execute.tsx @@ -17,7 +17,7 @@ import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; import { getBearerToken } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cn } from "@/lib/utils"; // ============================================================================ @@ -158,9 +158,7 @@ function truncateCommand(command: string, maxLen = 80): string { async function downloadSandboxFile(threadId: string, filePath: string, fileName: string) { const token = getBearerToken(); - const url = buildBackendUrl(`/api/v1/threads/${threadId}/sandbox/download`, { - path: filePath, - }); + const url = `${BACKEND_URL}/api/v1/threads/${threadId}/sandbox/download?path=${encodeURIComponent(filePath)}`; const res = await fetch(url, { headers: { Authorization: `Bearer ${token || ""}` }, }); diff --git a/surfsense_web/components/tool-ui/video-presentation/generate-video-presentation.tsx b/surfsense_web/components/tool-ui/video-presentation/generate-video-presentation.tsx index 9f2115073..1db8dabb0 100644 --- a/surfsense_web/components/tool-ui/video-presentation/generate-video-presentation.tsx +++ b/surfsense_web/components/tool-ui/video-presentation/generate-video-presentation.tsx @@ -10,7 +10,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { baseApiService } from "@/lib/apis/base-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { compileCheck, compileToComponent } from "@/lib/remotion/compile-check"; import { FPS } from "@/lib/remotion/constants"; import { @@ -137,6 +137,7 @@ function VideoPresentationPlayer({ const [isPptxExporting, setIsPptxExporting] = useState(false); const [pptxProgress, setPptxProgress] = useState<string | null>(null); + const backendUrl = BACKEND_URL ?? ""; const audioBlobUrlsRef = useRef<string[]>([]); const loadPresentation = useCallback(async () => { @@ -176,7 +177,7 @@ function VideoPresentationPlayer({ title: scene.title ?? slide.title, code: scene.code, durationInFrames, - audioUrl: slide.audio_url ? buildBackendUrl(slide.audio_url) : undefined, + audioUrl: slide.audio_url ? `${backendUrl}${slide.audio_url}` : undefined, }); } @@ -221,7 +222,7 @@ function VideoPresentationPlayer({ } finally { setIsLoading(false); } - }, [presentationId, shareToken]); + }, [presentationId, backendUrl, shareToken]); useEffect(() => { loadPresentation(); diff --git a/surfsense_web/content/docs/docker-installation/dev-compose.mdx b/surfsense_web/content/docs/docker-installation/dev-compose.mdx index e16c6a685..599e9beb2 100644 --- a/surfsense_web/content/docs/docker-installation/dev-compose.mdx +++ b/surfsense_web/content/docs/docker-installation/dev-compose.mdx @@ -10,11 +10,7 @@ cd SurfSense/docker docker compose -f docker-compose.dev.yml up --build ``` -This file builds the backend and frontend from your local source code (instead -of pulling prebuilt images) and includes pgAdmin for database inspection at -[http://localhost:5050](http://localhost:5050). It intentionally keeps raw -frontend, backend, and zero-cache ports published for debugging. Use the -production `docker-compose.yml` for the default Caddy single-origin setup. +This file builds the backend and frontend from your local source code (instead of pulling prebuilt images) and includes pgAdmin for database inspection at [http://localhost:5050](http://localhost:5050). Use the production `docker-compose.yml` for all other cases. ## Dev-Only Environment Variables @@ -26,14 +22,9 @@ The following `.env` variables are **only used by the dev compose file** (they h | `PGADMIN_DEFAULT_EMAIL` | pgAdmin login email | `admin@surfsense.com` | | `PGADMIN_DEFAULT_PASSWORD` | pgAdmin login password | `surfsense` | | `REDIS_PORT` | Exposed Redis port (internal-only in prod) | `6379` | -| `AUTH_TYPE` | Runtime auth mode | `LOCAL` | -| `ETL_SERVICE` | Runtime document parsing service | `DOCLING` | -| `DEPLOYMENT_MODE` | Runtime deployment mode | `self-hosted` | -| `ZERO_CACHE_PORT` | Exposed zero-cache port for debugging | `4848` | +| `NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE` | Frontend build arg for auth type | `LOCAL` | +| `NEXT_PUBLIC_ETL_SERVICE` | Frontend build arg for ETL service | `DOCLING` | +| `NEXT_PUBLIC_ZERO_CACHE_URL` | Frontend build arg for Zero-cache URL | `http://localhost:4848` | +| `NEXT_PUBLIC_DEPLOYMENT_MODE` | Frontend build arg for deployment mode | `self-hosted` | -In the production compose file, the frontend reads `AUTH_TYPE`, `ETL_SERVICE`, -and `DEPLOYMENT_MODE` at request time. Browser API and Zero traffic are -same-origin relative through bundled Caddy. -Production Docker exposes only the bundled Caddy proxy by default; dev compose -keeps direct service ports so contributors can inspect and restart individual -services without going through the proxy. +In the production compose file, the `NEXT_PUBLIC_*` frontend variables are automatically derived from `AUTH_TYPE`, `ETL_SERVICE`, and the port settings. In the dev compose file, they are passed as build args since the frontend is built from source. diff --git a/surfsense_web/content/docs/docker-installation/docker-compose.mdx b/surfsense_web/content/docs/docker-installation/docker-compose.mdx index bf71c077b..60b5e67b6 100644 --- a/surfsense_web/content/docs/docker-installation/docker-compose.mdx +++ b/surfsense_web/content/docs/docker-installation/docker-compose.mdx @@ -15,9 +15,9 @@ docker compose up -d After starting, access SurfSense at: -- **SurfSense**: [http://localhost:3929](http://localhost:3929) -- **Backend API**: [http://localhost:3929/api/v1](http://localhost:3929/api/v1) -- **Zero sync**: `ws://localhost:3929/zero` +- **Frontend**: [http://localhost:3929](http://localhost:3929) +- **Backend API**: [http://localhost:8929](http://localhost:8929) +- **API Docs**: [http://localhost:8929/docs](http://localhost:8929/docs) --- ## Configuration @@ -99,59 +99,24 @@ docker run -d --name watchtower \ SurfSense containers are labeled for Watchtower, so `--label-enable` limits updates to the SurfSense services. -### Public URL and Ports +### Ports | Variable | Description | Default | |----------|-------------|---------| -| `SURFSENSE_PUBLIC_URL` | Public origin used by the frontend, backend OAuth callbacks, and Zero browser URL | `http://localhost:3929` | -| `SURFSENSE_SITE_ADDRESS` | Caddy site address. `:80` means local plain HTTP; a hostname enables automatic HTTPS | `:80` | -| `LISTEN_HTTP_PORT` | Host port mapped to Caddy's HTTP listener | `3929` | -| `LISTEN_HTTPS_PORT` | Host port mapped to Caddy's HTTPS listener for domain mode | `443` | +| `FRONTEND_PORT` | Frontend service port | `3929` | +| `BACKEND_PORT` | Backend API service port | `8929` | +| `ZERO_CACHE_PORT` | Zero-cache real-time sync port | `5929` | -SurfSense includes Caddy by default. The `frontend`, `backend`, and -`zero-cache` containers are internal-only in the production compose file; the -browser reaches them through Caddy path routing. +### Custom Domain / Reverse Proxy -### Custom Domain / Automatic HTTPS - -For a real domain, point DNS at the Docker host and set: - -```dotenv -SURFSENSE_SITE_ADDRESS=surf.example.com -LISTEN_HTTP_PORT=80 -LISTEN_HTTPS_PORT=443 -CERT_EMAIL=you@example.com -SURFSENSE_PUBLIC_URL=https://surf.example.com -``` - -Caddy will issue and renew Let's Encrypt certificates automatically. Ports 80 -and 443 must be reachable from the internet for the default HTTP-01 challenge. +Only set these if serving SurfSense on a real domain via a reverse proxy (Caddy, Nginx, Cloudflare Tunnel, etc.). Leave commented out for standard localhost deployments. | Variable | Description | |----------|-------------| -| `CERT_EMAIL` | Optional ACME contact email | -| `CERT_ACME_CA` | ACME directory URL; use Let's Encrypt staging when testing cert issuance | -| `CERT_ACME_DNS` | DNS-01 challenge config; requires the custom Caddy build | -| `TRUSTED_PROXIES` | CIDR ranges trusted for forwarded client IP headers | -| `SURFSENSE_MAX_BODY_SIZE` | Upload limit enforced at the proxy | - -### Bring Your Own Proxy - -If you already run nginx, Traefik, Cloudflare Tunnel, or another ingress, you -can comment out the `proxy` service and route traffic to the internal services -with the same path contract: - -| Public path | Upstream | -|-------------|----------| -| `/auth/*` | `backend:8000` | -| `/api/v1/*` | `backend:8000` | -| `/zero/*` | `zero-cache:4848` | -| `/*` | `frontend:3000` | - -Alternative proxies must preserve WebSocket upgrades for `/zero`, avoid -buffering streaming responses, allow long-running requests, and support large -uploads. For DNS-01 or wildcard certificates with Caddy, build -`docker/proxy/Dockerfile` and set `CERT_ACME_DNS` for your DNS provider. +| `NEXT_FRONTEND_URL` | Public frontend URL (e.g. `https://app.yourdomain.com`) | +| `BACKEND_URL` | Public backend URL for OAuth callbacks (e.g. `https://api.yourdomain.com`) | +| `NEXT_PUBLIC_FASTAPI_BACKEND_URL` | Backend URL used by the frontend (e.g. `https://api.yourdomain.com`) | +| `NEXT_PUBLIC_ZERO_CACHE_URL` | Zero-cache URL used by the frontend (e.g. `https://zero.yourdomain.com`) | ### Zero-cache (Real-Time Sync) @@ -200,10 +165,7 @@ Create credentials at the [Google Cloud Console](https://console.cloud.google.co ### Connector OAuth Keys -Uncomment the connectors you want to use. Redirect URIs follow the single-origin -pattern `${SURFSENSE_PUBLIC_URL}/api/v1/auth/<connector>/connector/callback`. -For local Docker defaults, that means -`http://localhost:3929/api/v1/auth/<connector>/connector/callback`. +Uncomment the connectors you want to use. Redirect URIs follow the pattern `http://localhost:8000/api/v1/auth/<connector>/connector/callback`. | Connector | Variables | |-----------|-----------| @@ -256,7 +218,6 @@ for full setup. | Service | Description | |---------|-------------| -| `proxy` | Caddy reverse proxy; the only public ingress in production Docker | | `db` | PostgreSQL with pgvector extension | | `migrations` | Short-lived: runs `alembic upgrade head` and verifies `zero_publication`, then exits | | `redis` | Message broker for Celery | @@ -265,7 +226,7 @@ for full setup. | `celery_worker` | Background task processing (document indexing, etc.) | | `celery_beat` | Periodic task scheduler (connector sync) | | `zero-cache` | Rocicorp Zero real-time sync (replicates Postgres to clients) | -| `frontend` | Next.js web application, internal behind Caddy | +| `frontend` | Next.js web application | All services start automatically with `docker compose up -d`. @@ -331,9 +292,9 @@ docker compose down -v ## Troubleshooting -- **Port already in use**: Change `LISTEN_HTTP_PORT` in `.env` and restart. In domain mode, use ports `80` and `443` so Caddy can complete certificate issuance. +- **Ports already in use**: Change the relevant `*_PORT` variable in `.env` and restart. - **Permission errors on Linux**: You may need to prefix `docker` commands with `sudo`. -- **Real-time updates not working**: Open DevTools → Console and check for WebSocket errors. In production Docker the expected URL is `${SURFSENSE_PUBLIC_URL}/zero`. +- **Real-time updates not working**: Open DevTools → Console and check for WebSocket errors. Verify `NEXT_PUBLIC_ZERO_CACHE_URL` matches the running zero-cache address. - **Line ending issues on Windows**: Run `git config --global core.autocrlf true` before cloning. ### Migration service exited non-zero diff --git a/surfsense_web/content/docs/docker-installation/install-script.mdx b/surfsense_web/content/docs/docker-installation/install-script.mdx index fb7e6b5b6..9f8acf9e5 100644 --- a/surfsense_web/content/docs/docker-installation/install-script.mdx +++ b/surfsense_web/content/docs/docker-installation/install-script.mdx @@ -74,27 +74,7 @@ If Watchtower is enabled, it preserves the running image variant tag automatical After starting, access SurfSense at: -- **SurfSense**: [http://localhost:3929](http://localhost:3929) -- **Backend API**: [http://localhost:3929/api/v1](http://localhost:3929/api/v1) -- **Zero sync**: `ws://localhost:3929/zero` - -The installer uses the bundled Caddy reverse proxy by default. The backend and -zero-cache containers are not published on separate host ports in the production -stack. - -For a custom domain, edit `surfsense/.env` after installation: - -```dotenv -SURFSENSE_SITE_ADDRESS=surf.example.com -LISTEN_HTTP_PORT=80 -LISTEN_HTTPS_PORT=443 -CERT_EMAIL=you@example.com -SURFSENSE_PUBLIC_URL=https://surf.example.com -``` - -Then run: - -```bash -cd surfsense -docker compose up -d --wait -``` +- **Frontend**: [http://localhost:3929](http://localhost:3929) +- **Backend API**: [http://localhost:8929](http://localhost:8929) +- **API Docs**: [http://localhost:8929/docs](http://localhost:8929/docs) +- **Zero-cache**: [http://localhost:5929](http://localhost:5929) diff --git a/surfsense_web/content/docs/how-to/meta.json b/surfsense_web/content/docs/how-to/meta.json index 477fcafc4..329b7172e 100644 --- a/surfsense_web/content/docs/how-to/meta.json +++ b/surfsense_web/content/docs/how-to/meta.json @@ -1,6 +1,6 @@ { "title": "How to", - "pages": ["zero-sync", "realtime-collaboration", "web-search"], + "pages": ["zero-sync", "realtime-collaboration", "web-search", "ollama"], "icon": "Compass", "defaultOpen": false } diff --git a/surfsense_web/content/docs/how-to/ollama.mdx b/surfsense_web/content/docs/how-to/ollama.mdx new file mode 100644 index 000000000..48b231705 --- /dev/null +++ b/surfsense_web/content/docs/how-to/ollama.mdx @@ -0,0 +1,90 @@ +--- +title: Connect Ollama +description: Simple setup guide for using Ollama with SurfSense across local, Docker, remote, and cloud setups +--- + +# Connect Ollama + +Use this page to choose the correct **API Base URL** when adding an Ollama provider in SurfSense. + +## 1) Pick your API Base URL + +| Ollama location | SurfSense location | API Base URL | +|---|---|---| +| Same machine | No Docker | `http://localhost:11434` | +| Host machine (macOS/Windows) | Docker Desktop | `http://host.docker.internal:11434` | +| Host machine (Linux) | Docker Compose | `http://host.docker.internal:11434` | +| Same Docker Compose stack | Docker Compose | `http://ollama:11434` | +| Another machine in your network | Any | `http://<lan-ip>:11434` | +| Public Ollama endpoint / proxy / cloud | Any | `http(s)://<your-domain-or-endpoint>` | + +If SurfSense runs in Docker, do not use `localhost` unless Ollama is in the same container. + +## 2) Add Ollama in SurfSense + +Go to **Search Space Settings -> Agent Models -> Add Model** and set: + +- Provider: `OLLAMA` +- Model name: your model tag, for example `llama3.2` or `qwen3:8b` +- API Base URL: from the table above +- API key: + - local/self-hosted Ollama: any non-empty value + - Ollama cloud/proxied auth: real key or token required by that endpoint + +Save. SurfSense validates the connection immediately. + +## 3) Common setups + +### A) SurfSense in Docker Desktop, Ollama on your host + +Use: + +```text +http://host.docker.internal:11434 +``` + +### B) Ollama as a service in the same Compose + +Use API Base URL: + +```text +http://ollama:11434 +``` + +Minimal service example: + +```yaml +ollama: + image: ollama/ollama:latest + volumes: + - ollama_data:/root/.ollama + ports: + - "11434:11434" +``` + +### C) Ollama on another machine + +Ollama binds to `127.0.0.1` by default. Make it reachable on the network: + +- Set `OLLAMA_HOST=0.0.0.0:11434` on the machine/service running Ollama +- Open firewall port `11434` +- Use `http://<lan-ip>:11434` in SurfSense's API Base URL + +## 4) Quick troubleshooting + +| Error | Cause | Fix | +|---|---|---| +| `Cannot connect to host localhost:11434` | Wrong URL from Dockerized backend | Use `host.docker.internal` or `ollama` | +| `Cannot connect to host <lan-ip>:11434` | Ollama not exposed on network or firewall blocked | Set `OLLAMA_HOST=0.0.0.0:11434`, allow port 11434 | +| URL starts with `/%20http://...` | Leading space in URL | Re-enter API Base URL without spaces | +| `model not found` | Model not pulled on Ollama | Run `ollama pull <model>` | + +If needed, test from the backend container using the same host you put in **API Base URL**: + +```bash +docker compose exec backend curl -v <YOUR_API_BASE_URL>/api/tags +``` + +## See also + +- [Docker Installation](/docs/docker-installation/docker-compose) \ No newline at end of file diff --git a/surfsense_web/content/docs/how-to/zero-sync.mdx b/surfsense_web/content/docs/how-to/zero-sync.mdx index 728da9b86..7007e6637 100644 --- a/surfsense_web/content/docs/how-to/zero-sync.mdx +++ b/surfsense_web/content/docs/how-to/zero-sync.mdx @@ -32,10 +32,10 @@ zero-cache is included in the Docker Compose setup. The key environment variable | Variable | Description | Default | |----------|-------------|---------| -| `SURFSENSE_PUBLIC_URL` | Public SurfSense origin used by the browser | `http://localhost:3929` | +| `ZERO_CACHE_PORT` | Port for the zero-cache service | `5929` (prod) / `4848` (dev) | | `ZERO_ADMIN_PASSWORD` | Password for the zero-cache admin UI and `/statz` endpoint | `surfsense-zero-admin` | | `ZERO_UPSTREAM_DB` | PostgreSQL connection URL for replication | Built from `DB_*` vars | -| `/zero` | Same-origin browser path Caddy routes to zero-cache | `${SURFSENSE_PUBLIC_URL}/zero` | +| `NEXT_PUBLIC_ZERO_CACHE_URL` | URL the frontend uses to connect to zero-cache | `http://localhost:<ZERO_CACHE_PORT>` | | `ZERO_APP_PUBLICATIONS` | PostgreSQL publication restricting which tables are replicated | `zero_publication` | | `ZERO_NUM_SYNC_WORKERS` | Number of view-sync worker processes. Must be ≤ `ZERO_UPSTREAM_MAX_CONNS` and ≤ `ZERO_CVR_MAX_CONNS` | `4` | | `ZERO_UPSTREAM_MAX_CONNS` | Max connections to upstream PostgreSQL for mutations | `20` | @@ -64,18 +64,14 @@ If running the frontend outside Docker (e.g. `pnpm dev`), you need: ``` Run `uv run alembic upgrade head` from `surfsense_backend/` **before** starting this container so the `zero_publication` exists. -2. If the frontend is not behind bundled Caddy, set `NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848` before building/running the frontend so the browser connects directly to zero-cache. +2. **`NEXT_PUBLIC_ZERO_CACHE_URL`** set in `surfsense_web/.env` (default: `http://localhost:4848`). 3. **`wal_level = logical`** in your PostgreSQL config (see [Manual Installation → Configure PostgreSQL for Zero Sync](/docs/manual-installation#3-configure-postgresql-for-zero-sync)). For the full manual setup walkthrough, see the [Manual Installation guide](/docs/manual-installation). ### Custom Domain / Reverse Proxy -The production Docker stack includes Caddy by default. Zero is exposed under the -same public origin as the app at `${SURFSENSE_PUBLIC_URL}/zero`, for example -`https://surf.example.com/zero`. Zero accepts this single path-component base -URL, so Caddy forwards `/zero/*` to the internal `zero-cache:4848` service -without stripping the prefix. +When deploying behind a reverse proxy, set `NEXT_PUBLIC_ZERO_CACHE_URL` to your public zero-cache URL (e.g., `https://zero.yourdomain.com`). The zero-cache service must be accessible via WebSocket from the browser. ### Database Requirements @@ -114,7 +110,7 @@ Zero syncs the following tables for real-time features: - **zero-cache not starting**: Check `docker compose logs zero-cache`. Ensure PostgreSQL has `wal_level=logical` (configured in `postgresql.conf`). - **"Insufficient upstream connections" error**: zero-cache defaults `ZERO_NUM_SYNC_WORKERS` to the number of CPU cores, which can exceed connection pool limits on high-core machines. Lower `ZERO_NUM_SYNC_WORKERS` or raise `ZERO_UPSTREAM_MAX_CONNS` / `ZERO_CVR_MAX_CONNS` in your `.env`. -- **Frontend not syncing**: Open DevTools → Console and check for WebSocket connection errors. In production Docker, verify Caddy serves `${SURFSENSE_PUBLIC_URL}/zero`. In manual local development, verify `NEXT_PUBLIC_ZERO_CACHE_URL` points at the running zero-cache port. +- **Frontend not syncing**: Open DevTools → Console and check for WebSocket connection errors. Verify `NEXT_PUBLIC_ZERO_CACHE_URL` matches the running zero-cache address. - **Stale data after restart**: zero-cache rebuilds its SQLite replica from PostgreSQL on startup. This may take a moment for large databases. ## Learn More diff --git a/surfsense_web/content/docs/index.mdx b/surfsense_web/content/docs/index.mdx index c8540fed0..4a321b376 100644 --- a/surfsense_web/content/docs/index.mdx +++ b/surfsense_web/content/docs/index.mdx @@ -5,7 +5,7 @@ icon: BookOpen --- import { Card, Cards } from 'fumadocs-ui/components/card'; -import { ClipboardCheck, Download, Container, Wrench, Cable, BookOpen, FlaskConical, Heart, MessageCircle, Cpu } from 'lucide-react'; +import { ClipboardCheck, Download, Container, Wrench, Cable, BookOpen, FlaskConical, Heart, MessageCircle } from 'lucide-react'; Welcome to **SurfSense's Documentation!** Here, you'll find everything you need to get the most out of SurfSense. Dive in to explore how SurfSense can be your AI-powered research companion. @@ -34,12 +34,6 @@ Welcome to **SurfSense's Documentation!** Here, you'll find everything you need description="Set up SurfSense manually from source" href="/docs/manual-installation" /> - <Card - icon={<Cpu />} - title="Local Models" - description="Connect local model servers" - href="/docs/local-models" - /> <Card icon={<Cable />} title="Connectors" diff --git a/surfsense_web/content/docs/local-models/index.mdx b/surfsense_web/content/docs/local-models/index.mdx deleted file mode 100644 index e1839a4f0..000000000 --- a/surfsense_web/content/docs/local-models/index.mdx +++ /dev/null @@ -1,30 +0,0 @@ ---- -title: Local Models -description: Connect local model servers to SurfSense ---- - -import { Card, Cards } from 'fumadocs-ui/components/card'; - -# Local Models - -SurfSense can use local model servers such as Ollama and LM Studio. - -The API Base URL is read by the SurfSense backend. If SurfSense runs in Docker, use an address the backend container can reach. - -<Cards> - <Card - title="Ollama" - description="Connect an Ollama server and discover local model tags" - href="/docs/local-models/ollama" - /> - <Card - title="LM Studio" - description="Connect an LM Studio local server" - href="/docs/local-models/lm-studio" - /> - <Card - title="Other Local Servers" - description="Connect llama.cpp, vLLM, LocalAI, LiteLLM Proxy, and more" - href="/docs/local-models/other-local-servers" - /> -</Cards> diff --git a/surfsense_web/content/docs/local-models/lm-studio.mdx b/surfsense_web/content/docs/local-models/lm-studio.mdx deleted file mode 100644 index 6877786e5..000000000 --- a/surfsense_web/content/docs/local-models/lm-studio.mdx +++ /dev/null @@ -1,92 +0,0 @@ ---- -title: LM Studio -description: Connect LM Studio to SurfSense ---- - -# Connect LM Studio - -Connect to your LM Studio local server. Add it from -**Search Space Settings > Models**. - -## Base URL - -LM Studio uses an OpenAI compatible server. - -### SurfSense Runs in Docker - -Use this when SurfSense is running from Docker and LM Studio is running on your computer. - -```text -http://host.docker.internal:1234/v1 -``` - -<Callout type="info"> -This is the default in SurfSense. -</Callout> - -### SurfSense Runs Without Docker - -Use this when SurfSense and LM Studio both run directly on the same computer. - -```text -http://localhost:1234/v1 -``` - -### LM Studio Runs on Another Computer - -Use this when LM Studio is running on another machine in your network. - -```text -http://<host>:1234/v1 -``` - -Replace `<host>` with the LAN IP or domain for that machine. - -## LM Studio Setup - -1. Open LM Studio. -2. Load a model. -3. Start the local server. -4. Confirm the server listens on port `1234`. - -## Add the Connection - -1. Open Search Space Settings. -2. Go to Models. -3. Select LM Studio. -4. Set API Base URL. -5. Leave API Key empty unless your server requires one. -6. Select the models you want to enable. -7. Save the connection. - -SurfSense discovers models from `/v1/models`. If you enter the URL without `/v1`, SurfSense adds it for requests. - -## Verify - -From the host: - -```bash -curl http://localhost:1234/v1/models -``` - -From the SurfSense backend container: - -```bash -docker compose exec backend curl http://host.docker.internal:1234/v1/models -``` - -## Troubleshooting - -### Connection refused - -LM Studio is not reachable from the backend. - -Start the LM Studio server and confirm that port `1234` is open. - -### No models found - -Load a model in LM Studio, then refresh model discovery in SurfSense. - -### Endpoint returned 404 - -Use an OpenAI compatible server URL. The models endpoint must be available at `/v1/models`. diff --git a/surfsense_web/content/docs/local-models/meta.json b/surfsense_web/content/docs/local-models/meta.json deleted file mode 100644 index d904a8d5c..000000000 --- a/surfsense_web/content/docs/local-models/meta.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "title": "Local Models", - "pages": ["ollama", "lm-studio", "other-local-servers"], - "icon": "Cpu", - "defaultOpen": false -} diff --git a/surfsense_web/content/docs/local-models/ollama.mdx b/surfsense_web/content/docs/local-models/ollama.mdx deleted file mode 100644 index b062d98b0..000000000 --- a/surfsense_web/content/docs/local-models/ollama.mdx +++ /dev/null @@ -1,102 +0,0 @@ ---- -title: Ollama -description: Connect Ollama to SurfSense ---- - -# Connect Ollama - -Connect to your Ollama local server. Add it from -**Search Space Settings > Models**. - -## Base URL - -Choose the URL from where the SurfSense backend runs. - -### SurfSense Runs in Docker - -Use this when SurfSense is running from Docker and Ollama is running on your computer. - -```text -http://host.docker.internal:11434 -``` - -<Callout type="info"> -This is the default in SurfSense. -</Callout> - -### Ollama Runs in Docker - -Use this only when Ollama is a service in the same Compose stack as SurfSense. - -```text -http://ollama:11434 -``` - -### SurfSense Runs Without Docker - -Use this when SurfSense and Ollama both run directly on the same computer. - -```text -http://localhost:11434 -``` - -### Ollama Runs on Another Computer - -Use this when Ollama is running on another machine in your network. - -```text -http://<host>:11434 -``` - -Replace `<host>` with the LAN IP or domain for that machine. - -## Add the Connection - -1. Open Search Space Settings. -2. Go to Models. -3. Select Ollama. -4. Set API Base URL. -5. Leave API Key empty unless your endpoint needs one. -6. Select the models you want to enable. -7. Save the connection. - -Do not add `/v1` to the URL. SurfSense uses Ollama native routes such as `/api/version` and `/api/tags`. - -## Verify - -From the host: - -```bash -curl http://localhost:11434/api/version -``` - -From the SurfSense backend container: - -```bash -docker compose exec backend curl http://host.docker.internal:11434/api/version -docker compose exec backend curl http://host.docker.internal:11434/api/tags -``` - -## Troubleshooting - -### Name or service not known - -The backend cannot resolve the host name. - -Use `http://host.docker.internal:11434` unless you run Ollama as a Compose service named `ollama`. - -### Connection refused - -Ollama is not reachable from the backend. - -Start Ollama and confirm that port `11434` is open. - -### No models found - -Pull at least one model: - -```bash -ollama pull llama3.2 -``` - -Then refresh model discovery in SurfSense. diff --git a/surfsense_web/content/docs/local-models/other-local-servers.mdx b/surfsense_web/content/docs/local-models/other-local-servers.mdx deleted file mode 100644 index 669684929..000000000 --- a/surfsense_web/content/docs/local-models/other-local-servers.mdx +++ /dev/null @@ -1,109 +0,0 @@ ---- -title: Other Local Servers -description: Connect local OpenAI compatible model servers ---- - -# Connect Other Local Servers - -Connect to llama.cpp, vLLM, LocalAI, LiteLLM Proxy, and other servers -that expose OpenAI compatible routes. - -SurfSense discovers models from: - -```text -/v1/models -``` - -Chat requests use the same `/v1` base URL. - -## Pick Your Setup - -Use one of these URL patterns. - -### SurfSense Runs in Docker - -Use this when SurfSense is running from Docker and the model server is running on your computer. - -```text -http://host.docker.internal:<port>/v1 -``` - -Common ports: - -| Server | Port | -|---|---| -| llama.cpp | `10000` | -| vLLM | `8000` | -| LocalAI | `8080` | -| LiteLLM Proxy | `4000` | -| text-generation-webui | `5000` | - -### SurfSense Runs Without Docker - -Use this when SurfSense and the model server both run directly on the same computer. - -```text -http://localhost:<port>/v1 -``` - -### Model Server Runs on Another Computer - -Use this when the model server is running on another machine in your network. - -```text -http://<host>:<port>/v1 -``` - -## Add the Connection - -1. Open Search Space Settings. -2. Go to Models. -3. Select OpenAI Compatible. -4. Set API Base URL. -5. Add an API Key only if your server requires one. -6. Select the models you want to enable. -7. Save the connection. - -If you enter the URL without `/v1`, SurfSense adds `/v1` for requests. - -## Verify - -From the host: - -```bash -curl http://localhost:<port>/v1/models -``` - -From the SurfSense backend container: - -```bash -docker compose exec backend curl http://host.docker.internal:<port>/v1/models -``` - -A working server returns JSON with a `data` array. - -## When Not to Use This - -Use the Ollama provider for Ollama. It uses native routes such as `/api/tags`. - -Use the LM Studio provider for LM Studio. Its default URL is already set. - -## Troubleshooting - -### Endpoint returned 404 - -The server does not expose `/v1/models`. - -Enable the server's OpenAI compatible mode. - -### Connection refused - -The backend cannot reach the server. - -Check that the server is running and that the port is open. - -### No models found - -The server returned an empty model list. - -Load or serve a model, then refresh model discovery in SurfSense. diff --git a/surfsense_web/content/docs/manual-installation.mdx b/surfsense_web/content/docs/manual-installation.mdx index ab09b4155..22a8ff5a1 100644 --- a/surfsense_web/content/docs/manual-installation.mdx +++ b/surfsense_web/content/docs/manual-installation.mdx @@ -546,10 +546,7 @@ cd ../docker docker compose -f docker-compose.deps-only.yml up -d ``` -The deps-only stack exposes zero-cache on port `4848` by default. If your -frontend is not behind a reverse proxy that serves `/zero`, set -`NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848` before building/running the -frontend. +The deps-only stack exposes zero-cache on port `4848` by default. Keep `NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848` in your `surfsense_web/.env`. ## Frontend Setup @@ -580,13 +577,12 @@ Copy-Item -Path .env.example -Destination .env Edit the `.env` file and set: -| ENV VARIABLE | DESCRIPTION | -| --- | --- | -| `SURFSENSE_BACKEND_INTERNAL_URL` | Backend URL used by Next.js server routes, e.g. `http://localhost:8000` or `http://backend:8000` in Docker | -| `AUTH_TYPE` | Same value as backend auth type: `GOOGLE` for OAuth with Google, `LOCAL` for email/password authentication | -| `ETL_SERVICE` | Document parsing service (should match backend ETL_SERVICE): `UNSTRUCTURED`, `LLAMACLOUD`, or `DOCLING`; affects supported file formats in the upload interface | -| `DEPLOYMENT_MODE` | `self-hosted` or `cloud`; controls self-hosted-only connector visibility | -| `NEXT_PUBLIC_ZERO_CACHE_URL` | Only needed when the browser cannot reach Zero through same-origin `/zero`, e.g. manual local dev at `http://localhost:4848` | +| ENV VARIABLE | DESCRIPTION | +| ------------------------------- | ------------------------------------------- | +| NEXT_PUBLIC_FASTAPI_BACKEND_URL | Backend URL (e.g., `http://localhost:8000`) | +| NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE | Same value as set in backend AUTH_TYPE i.e `GOOGLE` for OAuth with Google, `LOCAL` for email/password authentication | +| NEXT_PUBLIC_ETL_SERVICE | Document parsing service (should match backend ETL_SERVICE): `UNSTRUCTURED`, `LLAMACLOUD`, or `DOCLING` - affects supported file formats in upload interface | +| NEXT_PUBLIC_ZERO_CACHE_URL | URL for Zero-cache real-time sync service (e.g., `http://localhost:4848`) | ### 2. Install Dependencies @@ -697,7 +693,7 @@ To verify your installation: - **Authentication Problems**: Check your Google OAuth configuration and ensure redirect URIs are set correctly - **LLM Errors**: Confirm your LLM API keys are valid and the selected models are accessible - **File Upload Failures**: Validate your ETL service API key (Unstructured.io or LlamaCloud) or ensure Docling is properly configured -- **Real-time updates not working / stale UI**: Verify zero-cache is running (`curl http://localhost:4848/keepalive` returns 200). Open browser DevTools → Console and look for WebSocket errors. In default Docker, confirm `/zero` is routed by Caddy. In manual local development, confirm `NEXT_PUBLIC_ZERO_CACHE_URL` in `surfsense_web/.env` matches the running zero-cache address. +- **Real-time updates not working / stale UI**: Verify zero-cache is running (`curl http://localhost:4848/keepalive` returns 200). Open browser DevTools → Console and look for WebSocket errors. Confirm `NEXT_PUBLIC_ZERO_CACHE_URL` in `surfsense_web/.env` matches the running zero-cache address. - **Zero-cache stuck on `Unknown or invalid publications. Specified: [zero_publication]`**: You skipped (or never ran) `uv run alembic upgrade head` from `surfsense_backend/`. Run it, then restart the zero-cache container with `docker restart surfsense-zero-cache`. - **Zero-cache crashes with `_zero.tableMetadata` errors**: A previous run left a half-built SQLite replica behind. Stop the container, remove the volume, and start fresh: `docker rm -f surfsense-zero-cache && docker volume rm surfsense-zero-cache && docker run -d ...` (re-run the command from [Zero-Cache Setup](#zero-cache-setup)). - **`wal_level` is not set to `logical`**: zero-cache requires logical replication. Set `wal_level = logical` in `postgresql.conf`, restart PostgreSQL, and verify with `SHOW wal_level;` in psql. diff --git a/surfsense_web/content/docs/messaging-channels/docker.mdx b/surfsense_web/content/docs/messaging-channels/docker.mdx index 63acfa0a8..3a4d4177f 100644 --- a/surfsense_web/content/docs/messaging-channels/docker.mdx +++ b/surfsense_web/content/docs/messaging-channels/docker.mdx @@ -15,19 +15,17 @@ wired by Compose. ## Public URLs For localhost-only testing, the defaults are enough for the SurfSense UI, but -public webhooks from Telegram, WhatsApp, and Slack require a public HTTPS -SurfSense URL. Use your deployed domain or a tunnel such as Cloudflare Tunnel -or ngrok. +public webhooks from Telegram, WhatsApp, and Slack require a public HTTPS backend +URL. Use your deployed backend URL or a tunnel such as Cloudflare Tunnel or +ngrok. -When using a custom domain or tunnel with the bundled Caddy proxy, set: +When using a custom domain or tunnel, set: ```bash -SURFSENSE_PUBLIC_URL=https://surf.example.com -SURFSENSE_SITE_ADDRESS=surf.example.com -LISTEN_HTTP_PORT=80 -LISTEN_HTTPS_PORT=443 -CERT_EMAIL=you@example.com -GATEWAY_BASE_URL=https://surf.example.com +BACKEND_URL=https://api.example.com +GATEWAY_BASE_URL=https://api.example.com +NEXT_FRONTEND_URL=https://app.example.com +NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.example.com ``` ## Environment Variables diff --git a/surfsense_web/content/docs/meta.json b/surfsense_web/content/docs/meta.json index 435e49f9f..74be10600 100644 --- a/surfsense_web/content/docs/meta.json +++ b/surfsense_web/content/docs/meta.json @@ -9,7 +9,6 @@ "installation", "manual-installation", "docker-installation", - "local-models", "messaging-channels", "connectors", "how-to", diff --git a/surfsense_web/contracts/enums/image-gen-providers.ts b/surfsense_web/contracts/enums/image-gen-providers.ts new file mode 100644 index 000000000..8410aeb4b --- /dev/null +++ b/surfsense_web/contracts/enums/image-gen-providers.ts @@ -0,0 +1,105 @@ +export interface ImageGenProvider { + value: string; + label: string; + example: string; + description: string; + apiBase?: string; +} + +/** + * Image generation providers supported by LiteLLM. + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const IMAGE_GEN_PROVIDERS: ImageGenProvider[] = [ + { + value: "OPENAI", + label: "OpenAI", + example: "dall-e-3, gpt-image-1, dall-e-2", + description: "DALL-E and GPT Image models", + }, + { + value: "AZURE_OPENAI", + label: "Azure OpenAI", + example: "azure/dall-e-3, azure/gpt-image-1", + description: "OpenAI image models on Azure", + }, + { + value: "GOOGLE", + label: "Google AI Studio", + example: "gemini/imagen-3.0-generate-002", + description: "Google AI Studio image generation", + }, + { + value: "VERTEX_AI", + label: "Google Vertex AI", + example: "vertex_ai/imagegeneration@006", + description: "Vertex AI image generation models", + }, + { + value: "BEDROCK", + label: "AWS Bedrock", + example: "bedrock/stability.stable-diffusion-xl-v0", + description: "Stable Diffusion on AWS Bedrock", + }, + { + value: "RECRAFT", + label: "Recraft", + example: "recraft/recraftv3", + description: "AI-powered design and image generation", + }, + { + value: "OPENROUTER", + label: "OpenRouter", + example: "openrouter/google/gemini-2.5-flash-image", + description: "Image generation via OpenRouter", + }, + { + value: "XINFERENCE", + label: "Xinference", + example: "xinference/stable-diffusion-xl", + description: "Self-hosted Stable Diffusion models", + }, + { + value: "NSCALE", + label: "Nscale", + example: "nscale/flux.1-schnell", + description: "Nscale image generation", + }, +]; + +/** + * Image generation models organized by provider. + */ +export interface ImageGenModel { + value: string; + label: string; + provider: string; +} + +export const IMAGE_GEN_MODELS: ImageGenModel[] = [ + // OpenAI + { value: "gpt-image-1", label: "GPT Image 1", provider: "OPENAI" }, + { value: "dall-e-3", label: "DALL-E 3", provider: "OPENAI" }, + { value: "dall-e-2", label: "DALL-E 2", provider: "OPENAI" }, + // Azure OpenAI + { value: "azure/dall-e-3", label: "DALL-E 3 (Azure)", provider: "AZURE_OPENAI" }, + { value: "azure/gpt-image-1", label: "GPT Image 1 (Azure)", provider: "AZURE_OPENAI" }, + // Recraft + { value: "recraft/recraftv3", label: "Recraft V3", provider: "RECRAFT" }, + // Bedrock + { + value: "bedrock/stability.stable-diffusion-xl-v0", + label: "Stable Diffusion XL", + provider: "BEDROCK", + }, + // Vertex AI + { + value: "vertex_ai/imagegeneration@006", + label: "Imagen 3", + provider: "VERTEX_AI", + }, +]; + +export function getImageGenModelsByProvider(provider: string): ImageGenModel[] { + return IMAGE_GEN_MODELS.filter((m) => m.provider === provider); +} diff --git a/surfsense_web/contracts/enums/llm-models.ts b/surfsense_web/contracts/enums/llm-models.ts new file mode 100644 index 000000000..9647c9d31 --- /dev/null +++ b/surfsense_web/contracts/enums/llm-models.ts @@ -0,0 +1,1558 @@ +export interface LLMModel { + value: string; + label: string; + provider: string; + contextWindow?: string; +} + +// Comprehensive models database organized by provider +export const LLM_MODELS: LLMModel[] = [ + // OpenAI + { + value: "gpt-4o", + label: "GPT-4o", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-mini", + label: "GPT-4o Mini", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-2024-11-20", + label: "GPT-4o (Nov 2024)", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-2024-08-06", + label: "GPT-4o (Aug 2024)", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-2024-05-13", + label: "GPT-4o (May 2024)", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4-turbo", + label: "GPT-4 Turbo", + provider: "OPENAI", + contextWindow: "128K", + }, + { value: "gpt-4", label: "GPT-4", provider: "OPENAI", contextWindow: "8K" }, + { + value: "gpt-3.5-turbo", + label: "GPT-3.5 Turbo", + provider: "OPENAI", + contextWindow: "16K", + }, + { value: "o1", label: "O1", provider: "OPENAI", contextWindow: "200K" }, + { + value: "o1-mini", + label: "O1 Mini", + provider: "OPENAI", + contextWindow: "128K", + }, + { + value: "o1-preview", + label: "O1 Preview", + provider: "OPENAI", + contextWindow: "128K", + }, + { value: "o3", label: "O3", provider: "OPENAI", contextWindow: "200K" }, + { + value: "o3-mini", + label: "O3 Mini", + provider: "OPENAI", + contextWindow: "200K", + }, + { + value: "o4-mini", + label: "O4 Mini", + provider: "OPENAI", + contextWindow: "200K", + }, + { + value: "gpt-4.1", + label: "GPT-4.1", + provider: "OPENAI", + contextWindow: "1M", + }, + { + value: "gpt-4.1-mini", + label: "GPT-4.1 Mini", + provider: "OPENAI", + contextWindow: "1M", + }, + { + value: "gpt-4.1-nano", + label: "GPT-4.1 Nano", + provider: "OPENAI", + contextWindow: "1M", + }, + { value: "gpt-5", label: "GPT-5", provider: "OPENAI", contextWindow: "272K" }, + { + value: "gpt-5-mini", + label: "GPT-5 Mini", + provider: "OPENAI", + contextWindow: "272K", + }, + { + value: "gpt-5-nano", + label: "GPT-5 Nano", + provider: "OPENAI", + contextWindow: "272K", + }, + { + value: "chatgpt-4o-latest", + label: "ChatGPT-4o Latest", + provider: "OPENAI", + contextWindow: "128K", + }, + + // Anthropic + { + value: "claude-3-5-sonnet-20241022", + label: "Claude 3.5 Sonnet", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-7-sonnet-20250219", + label: "Claude 3.7 Sonnet", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-4-sonnet-20250514", + label: "Claude 4 Sonnet", + provider: "ANTHROPIC", + contextWindow: "1M", + }, + { + value: "claude-4-opus-20250514", + label: "Claude 4 Opus", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-5-haiku-20241022", + label: "Claude 3.5 Haiku", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-haiku-4-5-20251001", + label: "Claude Haiku 4.5", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-opus-20240229", + label: "Claude 3 Opus", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-haiku-20240307", + label: "Claude 3 Haiku", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-sonnet-4-5-20250929", + label: "Claude Sonnet 4.5", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-opus-4-1-20250805", + label: "Claude Opus 4.1", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + + // Google (Gemini) + { + value: "gemini-3-flash-preview", + label: "Gemini 3 Flash", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-3-pro-preview", + label: "Gemini 3 Pro", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-2.5-flash", + label: "Gemini 2.5 Flash", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-2.5-pro", + label: "Gemini 2.5 Pro", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-2.0-flash", + label: "Gemini 2.0 Flash", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-2.0-flash-lite", + label: "Gemini 2.0 Flash Lite", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-1.5-flash", + label: "Gemini 1.5 Flash", + provider: "GOOGLE", + contextWindow: "1M", + }, + { + value: "gemini-1.5-pro", + label: "Gemini 1.5 Pro", + provider: "GOOGLE", + contextWindow: "2M", + }, + { + value: "gemini-pro", + label: "Gemini Pro", + provider: "GOOGLE", + contextWindow: "33K", + }, + { + value: "gemini-pro-vision", + label: "Gemini Pro Vision", + provider: "GOOGLE", + contextWindow: "16K", + }, + + // DeepSeek + { + value: "deepseek-chat", + label: "DeepSeek Chat", + provider: "DEEPSEEK", + contextWindow: "131K", + }, + { + value: "deepseek-reasoner", + label: "DeepSeek Reasoner", + provider: "DEEPSEEK", + contextWindow: "131K", + }, + { + value: "deepseek-coder", + label: "DeepSeek Coder", + provider: "DEEPSEEK", + contextWindow: "128K", + }, + + // xAI (Grok) + { value: "grok-4", label: "Grok 4", provider: "XAI", contextWindow: "256K" }, + { value: "grok-3", label: "Grok 3", provider: "XAI", contextWindow: "131K" }, + { + value: "grok-3-mini", + label: "Grok 3 Mini", + provider: "XAI", + contextWindow: "131K", + }, + { + value: "grok-3-fast-beta", + label: "Grok 3 Fast", + provider: "XAI", + contextWindow: "131K", + }, + { + value: "grok-3-mini-fast", + label: "Grok 3 Mini Fast", + provider: "XAI", + contextWindow: "131K", + }, + { value: "grok-2", label: "Grok 2", provider: "XAI", contextWindow: "131K" }, + { + value: "grok-2-vision", + label: "Grok 2 Vision", + provider: "XAI", + contextWindow: "33K", + }, + + // Azure OpenAI + { + value: "gpt-4o", + label: "Azure GPT-4o", + provider: "AZURE_OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-mini", + label: "Azure GPT-4o Mini", + provider: "AZURE_OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4o-2024-11-20", + label: "Azure GPT-4o (Nov 2024)", + provider: "AZURE_OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4-turbo", + label: "Azure GPT-4 Turbo", + provider: "AZURE_OPENAI", + contextWindow: "128K", + }, + { + value: "gpt-4", + label: "Azure GPT-4", + provider: "AZURE_OPENAI", + contextWindow: "8K", + }, + { + value: "gpt-35-turbo", + label: "Azure GPT-3.5 Turbo", + provider: "AZURE_OPENAI", + contextWindow: "4K", + }, + { + value: "o1", + label: "Azure O1", + provider: "AZURE_OPENAI", + contextWindow: "200K", + }, + { + value: "o1-mini", + label: "Azure O1 Mini", + provider: "AZURE_OPENAI", + contextWindow: "128K", + }, + { + value: "o3-mini", + label: "Azure O3 Mini", + provider: "AZURE_OPENAI", + contextWindow: "200K", + }, + { + value: "gpt-4.1", + label: "Azure GPT-4.1", + provider: "AZURE_OPENAI", + contextWindow: "1M", + }, + { + value: "gpt-4.1-mini", + label: "Azure GPT-4.1 Mini", + provider: "AZURE_OPENAI", + contextWindow: "1M", + }, + { + value: "gpt-5", + label: "Azure GPT-5", + provider: "AZURE_OPENAI", + contextWindow: "272K", + }, + + // AWS Bedrock + { + value: "anthropic.claude-3-5-sonnet-20241022-v2:0", + label: "Bedrock Claude 3.5 Sonnet", + provider: "BEDROCK", + contextWindow: "200K", + }, + { + value: "anthropic.claude-3-7-sonnet-20250219-v1:0", + label: "Bedrock Claude 3.7 Sonnet", + provider: "BEDROCK", + contextWindow: "200K", + }, + { + value: "anthropic.claude-4-sonnet-20250514-v1:0", + label: "Bedrock Claude 4 Sonnet", + provider: "BEDROCK", + contextWindow: "1M", + }, + { + value: "anthropic.claude-3-opus-20240229-v1:0", + label: "Bedrock Claude 3 Opus", + provider: "BEDROCK", + contextWindow: "200K", + }, + { + value: "anthropic.claude-3-haiku-20240307-v1:0", + label: "Bedrock Claude 3 Haiku", + provider: "BEDROCK", + contextWindow: "200K", + }, + { + value: "anthropic.claude-haiku-4-5-20251001-v1:0", + label: "Bedrock Claude Haiku 4.5", + provider: "BEDROCK", + contextWindow: "200K", + }, + { + value: "amazon.nova-pro-v1:0", + label: "Amazon Nova Pro", + provider: "BEDROCK", + contextWindow: "300K", + }, + { + value: "amazon.nova-lite-v1:0", + label: "Amazon Nova Lite", + provider: "BEDROCK", + contextWindow: "300K", + }, + { + value: "amazon.nova-micro-v1:0", + label: "Amazon Nova Micro", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama3-3-70b-instruct-v1:0", + label: "Bedrock Llama 3.3 70B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama3-1-405b-instruct-v1:0", + label: "Bedrock Llama 3.1 405B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama3-1-70b-instruct-v1:0", + label: "Bedrock Llama 3.1 70B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama3-1-8b-instruct-v1:0", + label: "Bedrock Llama 3.1 8B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama4-maverick-17b-instruct-v1:0", + label: "Bedrock Llama 4 Maverick 17B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "meta.llama4-scout-17b-instruct-v1:0", + label: "Bedrock Llama 4 Scout 17B", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "mistral.mistral-large-2407-v1:0", + label: "Bedrock Mistral Large", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "mistral.mixtral-8x7b-instruct-v0:1", + label: "Bedrock Mixtral 8x7B", + provider: "BEDROCK", + contextWindow: "32K", + }, + { + value: "cohere.command-r-plus-v1:0", + label: "Bedrock Cohere Command R+", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "cohere.command-r-v1:0", + label: "Bedrock Cohere Command R", + provider: "BEDROCK", + contextWindow: "128K", + }, + { + value: "ai21.jamba-1-5-large-v1:0", + label: "Bedrock Jamba 1.5 Large", + provider: "BEDROCK", + contextWindow: "256K", + }, + { + value: "ai21.jamba-1-5-mini-v1:0", + label: "Bedrock Jamba 1.5 Mini", + provider: "BEDROCK", + contextWindow: "256K", + }, + { + value: "deepseek.v3-v1:0", + label: "Bedrock DeepSeek V3", + provider: "BEDROCK", + contextWindow: "164K", + }, + + // Vertex AI + { + value: "gemini-2.5-flash", + label: "Vertex Gemini 2.5 Flash", + provider: "VERTEX_AI", + contextWindow: "1M", + }, + { + value: "gemini-2.5-pro", + label: "Vertex Gemini 2.5 Pro", + provider: "VERTEX_AI", + contextWindow: "1M", + }, + { + value: "gemini-2.0-flash", + label: "Vertex Gemini 2.0 Flash", + provider: "VERTEX_AI", + contextWindow: "1M", + }, + { + value: "gemini-1.5-flash", + label: "Vertex Gemini 1.5 Flash", + provider: "VERTEX_AI", + contextWindow: "1M", + }, + { + value: "gemini-1.5-pro", + label: "Vertex Gemini 1.5 Pro", + provider: "VERTEX_AI", + contextWindow: "2M", + }, + { + value: "claude-3-5-sonnet-v2@20241022", + label: "Vertex Claude 3.5 Sonnet", + provider: "VERTEX_AI", + contextWindow: "200K", + }, + { + value: "claude-3-7-sonnet@20250219", + label: "Vertex Claude 3.7 Sonnet", + provider: "VERTEX_AI", + contextWindow: "200K", + }, + { + value: "claude-sonnet-4@20250514", + label: "Vertex Claude Sonnet 4", + provider: "VERTEX_AI", + contextWindow: "1M", + }, + { + value: "claude-3-opus@20240229", + label: "Vertex Claude 3 Opus", + provider: "VERTEX_AI", + contextWindow: "200K", + }, + { + value: "claude-3-haiku@20240307", + label: "Vertex Claude 3 Haiku", + provider: "VERTEX_AI", + contextWindow: "200K", + }, + { + value: "claude-haiku-4-5@20251001", + label: "Vertex Claude Haiku 4.5", + provider: "VERTEX_AI", + contextWindow: "200K", + }, + { + value: "meta/llama-3.1-405b-instruct-maas", + label: "Vertex Llama 3.1 405B", + provider: "VERTEX_AI", + contextWindow: "128K", + }, + { + value: "mistral-large@2411-001", + label: "Vertex Mistral Large", + provider: "VERTEX_AI", + contextWindow: "128K", + }, + + // Groq + { + value: "llama-3.3-70b-versatile", + label: "Groq Llama 3.3 70B", + provider: "GROQ", + contextWindow: "128K", + }, + { + value: "llama-3.3-70b-specdec", + label: "Groq Llama 3.3 70B Specdec", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "llama-3.1-70b-versatile", + label: "Groq Llama 3.1 70B", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "llama-3.1-8b-instant", + label: "Groq Llama 3.1 8B", + provider: "GROQ", + contextWindow: "128K", + }, + { + value: "llama-3.2-90b-vision-preview", + label: "Groq Llama 3.2 90B Vision", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "llama-3.2-11b-vision-preview", + label: "Groq Llama 3.2 11B Vision", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "llama-3.2-3b-preview", + label: "Groq Llama 3.2 3B", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "llama-3.2-1b-preview", + label: "Groq Llama 3.2 1B", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "mixtral-8x7b-32768", + label: "Groq Mixtral 8x7B", + provider: "GROQ", + contextWindow: "33K", + }, + { + value: "gemma2-9b-it", + label: "Groq Gemma 2 9B", + provider: "GROQ", + contextWindow: "8K", + }, + { + value: "deepseek-r1-distill-llama-70b", + label: "Groq DeepSeek R1 Distill", + provider: "GROQ", + contextWindow: "128K", + }, + { + value: "meta-llama/llama-4-maverick-17b-128e-instruct", + label: "Groq Llama 4 Maverick", + provider: "GROQ", + contextWindow: "131K", + }, + { + value: "meta-llama/llama-4-scout-17b-16e-instruct", + label: "Groq Llama 4 Scout", + provider: "GROQ", + contextWindow: "131K", + }, + { + value: "openai/gpt-oss-120b", + label: "Groq GPT-OSS-120B", + provider: "GROQ", + contextWindow: "131K", + }, + { + value: "openai/gpt-oss-20b", + label: "Groq GPT-OSS-20B", + provider: "GROQ", + contextWindow: "131K", + }, + { + value: "moonshotai/kimi-k2-instruct", + label: "Groq Kimi K2", + provider: "GROQ", + contextWindow: "131K", + }, + + // Cohere + { + value: "command-a-03-2025", + label: "Command A (03-2025)", + provider: "COHERE", + contextWindow: "256K", + }, + { + value: "command-r-plus", + label: "Command R+", + provider: "COHERE", + contextWindow: "128K", + }, + { + value: "command-r", + label: "Command R", + provider: "COHERE", + contextWindow: "128K", + }, + { + value: "command-r-plus-08-2024", + label: "Command R+ (08-2024)", + provider: "COHERE", + contextWindow: "128K", + }, + { + value: "command-r-08-2024", + label: "Command R (08-2024)", + provider: "COHERE", + contextWindow: "128K", + }, + { + value: "command", + label: "Command", + provider: "COHERE", + contextWindow: "4K", + }, + + // Mistral + { + value: "mistral-large-latest", + label: "Mistral Large Latest", + provider: "MISTRAL", + contextWindow: "128K", + }, + { + value: "mistral-large-2411", + label: "Mistral Large 2411", + provider: "MISTRAL", + contextWindow: "128K", + }, + { + value: "mistral-medium-latest", + label: "Mistral Medium Latest", + provider: "MISTRAL", + contextWindow: "131K", + }, + { + value: "mistral-medium-2505", + label: "Mistral Medium 2505", + provider: "MISTRAL", + contextWindow: "131K", + }, + { + value: "mistral-small-latest", + label: "Mistral Small Latest", + provider: "MISTRAL", + contextWindow: "32K", + }, + { + value: "open-mistral-nemo", + label: "Mistral Nemo", + provider: "MISTRAL", + contextWindow: "128K", + }, + { + value: "open-mixtral-8x7b", + label: "Mixtral 8x7B", + provider: "MISTRAL", + contextWindow: "32K", + }, + { + value: "open-mixtral-8x22b", + label: "Mixtral 8x22B", + provider: "MISTRAL", + contextWindow: "65K", + }, + { + value: "codestral-latest", + label: "Codestral Latest", + provider: "MISTRAL", + contextWindow: "32K", + }, + { + value: "pixtral-large-latest", + label: "Pixtral Large Latest", + provider: "MISTRAL", + contextWindow: "128K", + }, + { + value: "magistral-medium-latest", + label: "Magistral Medium Latest", + provider: "MISTRAL", + contextWindow: "40K", + }, + + // Together AI + { + value: "meta-llama/Meta-Llama-3.3-70B-Instruct-Turbo", + label: "Together Llama 3.3 70B Turbo", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + label: "Together Llama 3.1 405B Turbo", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + label: "Together Llama 3.1 70B Turbo", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + label: "Together Llama 3.1 8B Turbo", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + label: "Together Llama 4 Maverick", + provider: "TOGETHER_AI", + contextWindow: "131K", + }, + { + value: "meta-llama/Llama-4-Scout-17B-16E-Instruct", + label: "Together Llama 4 Scout", + provider: "TOGETHER_AI", + contextWindow: "131K", + }, + { + value: "deepseek-ai/DeepSeek-V3.1", + label: "Together DeepSeek V3.1", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "deepseek-ai/DeepSeek-V3", + label: "Together DeepSeek V3", + provider: "TOGETHER_AI", + contextWindow: "66K", + }, + { + value: "deepseek-ai/DeepSeek-R1", + label: "Together DeepSeek R1", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "mistralai/Mixtral-8x7B-Instruct-v0.1", + label: "Together Mixtral 8x7B", + provider: "TOGETHER_AI", + contextWindow: "32K", + }, + { + value: "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8", + label: "Together Qwen3 Coder 480B", + provider: "TOGETHER_AI", + contextWindow: "256K", + }, + { + value: "Qwen/Qwen3-235B-A22B-Instruct-2507-tput", + label: "Together Qwen3 235B", + provider: "TOGETHER_AI", + contextWindow: "262K", + }, + { + value: "moonshotai/Kimi-K2-Instruct", + label: "Together Kimi K2", + provider: "TOGETHER_AI", + contextWindow: "131K", + }, + { + value: "openai/gpt-oss-120b", + label: "Together GPT-OSS-120B", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + { + value: "openai/gpt-oss-20b", + label: "Together GPT-OSS-20B", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, + + // Fireworks AI + { + value: "accounts/fireworks/models/llama-v3p3-70b-instruct", + label: "Fireworks Llama 3.3 70B", + provider: "FIREWORKS_AI", + contextWindow: "131K", + }, + { + value: "accounts/fireworks/models/llama-v3p1-405b-instruct", + label: "Fireworks Llama 3.1 405B", + provider: "FIREWORKS_AI", + contextWindow: "128K", + }, + { + value: "accounts/fireworks/models/llama4-maverick-instruct-basic", + label: "Fireworks Llama 4 Maverick", + provider: "FIREWORKS_AI", + contextWindow: "131K", + }, + { + value: "accounts/fireworks/models/llama4-scout-instruct-basic", + label: "Fireworks Llama 4 Scout", + provider: "FIREWORKS_AI", + contextWindow: "131K", + }, + { + value: "accounts/fireworks/models/deepseek-v3p1", + label: "Fireworks DeepSeek V3.1", + provider: "FIREWORKS_AI", + contextWindow: "128K", + }, + { + value: "accounts/fireworks/models/deepseek-v3", + label: "Fireworks DeepSeek V3", + provider: "FIREWORKS_AI", + contextWindow: "128K", + }, + { + value: "accounts/fireworks/models/deepseek-r1", + label: "Fireworks DeepSeek R1", + provider: "FIREWORKS_AI", + contextWindow: "128K", + }, + { + value: "accounts/fireworks/models/mixtral-8x22b-instruct-hf", + label: "Fireworks Mixtral 8x22B", + provider: "FIREWORKS_AI", + contextWindow: "66K", + }, + { + value: "accounts/fireworks/models/qwen2p5-coder-32b-instruct", + label: "Fireworks Qwen2.5 Coder 32B", + provider: "FIREWORKS_AI", + contextWindow: "4K", + }, + { + value: "accounts/fireworks/models/kimi-k2-instruct", + label: "Fireworks Kimi K2", + provider: "FIREWORKS_AI", + contextWindow: "131K", + }, + + // Replicate + { + value: "meta/llama-3-70b-instruct", + label: "Replicate Llama 3 70B", + provider: "REPLICATE", + contextWindow: "8K", + }, + { + value: "meta/llama-3-8b-instruct", + label: "Replicate Llama 3 8B", + provider: "REPLICATE", + contextWindow: "8K", + }, + { + value: "meta/llama-2-70b-chat", + label: "Replicate Llama 2 70B", + provider: "REPLICATE", + contextWindow: "4K", + }, + { + value: "mistralai/mixtral-8x7b-instruct-v0.1", + label: "Replicate Mixtral 8x7B", + provider: "REPLICATE", + contextWindow: "4K", + }, + + // Perplexity + { + value: "sonar-pro", + label: "Sonar Pro", + provider: "PERPLEXITY", + contextWindow: "200K", + }, + { + value: "sonar", + label: "Sonar", + provider: "PERPLEXITY", + contextWindow: "128K", + }, + { + value: "sonar-reasoning-pro", + label: "Sonar Reasoning Pro", + provider: "PERPLEXITY", + contextWindow: "128K", + }, + { + value: "sonar-reasoning", + label: "Sonar Reasoning", + provider: "PERPLEXITY", + contextWindow: "128K", + }, + { + value: "llama-3.1-sonar-large-128k-online", + label: "Llama 3.1 Sonar Large Online", + provider: "PERPLEXITY", + contextWindow: "127K", + }, + { + value: "llama-3.1-sonar-small-128k-online", + label: "Llama 3.1 Sonar Small Online", + provider: "PERPLEXITY", + contextWindow: "127K", + }, + + // OpenRouter + { + value: "anthropic/claude-4-opus", + label: "OpenRouter Claude 4 Opus", + provider: "OPENROUTER", + contextWindow: "200K", + }, + { + value: "anthropic/claude-sonnet-4", + label: "OpenRouter Claude Sonnet 4", + provider: "OPENROUTER", + contextWindow: "1M", + }, + { + value: "anthropic/claude-3.7-sonnet", + label: "OpenRouter Claude 3.7 Sonnet", + provider: "OPENROUTER", + contextWindow: "200K", + }, + { + value: "anthropic/claude-3.5-sonnet", + label: "OpenRouter Claude 3.5 Sonnet", + provider: "OPENROUTER", + contextWindow: "200K", + }, + { + value: "openai/gpt-5", + label: "OpenRouter GPT-5", + provider: "OPENROUTER", + contextWindow: "272K", + }, + { + value: "openai/gpt-4.1", + label: "OpenRouter GPT-4.1", + provider: "OPENROUTER", + contextWindow: "1M", + }, + { + value: "openai/gpt-4o", + label: "OpenRouter GPT-4o", + provider: "OPENROUTER", + contextWindow: "128K", + }, + { + value: "openai/o3-mini", + label: "OpenRouter O3 Mini", + provider: "OPENROUTER", + contextWindow: "128K", + }, + { + value: "x-ai/grok-4", + label: "OpenRouter Grok 4", + provider: "OPENROUTER", + contextWindow: "256K", + }, + { + value: "deepseek/deepseek-chat-v3.1", + label: "OpenRouter DeepSeek Chat V3.1", + provider: "OPENROUTER", + contextWindow: "164K", + }, + { + value: "deepseek/deepseek-r1", + label: "OpenRouter DeepSeek R1", + provider: "OPENROUTER", + contextWindow: "65K", + }, + { + value: "google/gemini-2.5-flash", + label: "OpenRouter Gemini 2.5 Flash", + provider: "OPENROUTER", + contextWindow: "1M", + }, + { + value: "google/gemini-2.5-pro", + label: "OpenRouter Gemini 2.5 Pro", + provider: "OPENROUTER", + contextWindow: "1M", + }, + + // Ollama (Local) + { + value: "llama3.3", + label: "Ollama Llama 3.3", + provider: "OLLAMA", + contextWindow: "128K", + }, + { + value: "llama3.1", + label: "Ollama Llama 3.1", + provider: "OLLAMA", + contextWindow: "8K", + }, + { + value: "llama3", + label: "Ollama Llama 3", + provider: "OLLAMA", + contextWindow: "8K", + }, + { + value: "llama2", + label: "Ollama Llama 2", + provider: "OLLAMA", + contextWindow: "4K", + }, + { + value: "mistral", + label: "Ollama Mistral", + provider: "OLLAMA", + contextWindow: "8K", + }, + { + value: "mixtral", + label: "Ollama Mixtral 8x7B", + provider: "OLLAMA", + contextWindow: "33K", + }, + { + value: "codellama", + label: "Ollama CodeLlama", + provider: "OLLAMA", + contextWindow: "4K", + }, + { + value: "deepseek-coder-v2-instruct", + label: "Ollama DeepSeek Coder V2", + provider: "OLLAMA", + contextWindow: "33K", + }, + + // Alibaba Qwen + { + value: "qwen-plus", + label: "Qwen Plus", + provider: "ALIBABA_QWEN", + contextWindow: "129K", + }, + { + value: "qwen-turbo", + label: "Qwen Turbo", + provider: "ALIBABA_QWEN", + contextWindow: "129K", + }, + { + value: "qwen-max", + label: "Qwen Max", + provider: "ALIBABA_QWEN", + contextWindow: "31K", + }, + { + value: "qwen-coder", + label: "Qwen Coder", + provider: "ALIBABA_QWEN", + contextWindow: "1M", + }, + { + value: "qwen3-32b", + label: "Qwen3 32B", + provider: "ALIBABA_QWEN", + contextWindow: "131K", + }, + { + value: "qwen3-30b-a3b", + label: "Qwen3 30B-A3B", + provider: "ALIBABA_QWEN", + contextWindow: "129K", + }, + { + value: "qwen3-coder-plus", + label: "Qwen3 Coder Plus", + provider: "ALIBABA_QWEN", + contextWindow: "998K", + }, + { + value: "qwq-plus", + label: "QwQ Plus", + provider: "ALIBABA_QWEN", + contextWindow: "98K", + }, + + // Moonshot (Kimi) + { + value: "kimi-latest", + label: "Kimi Latest", + provider: "MOONSHOT", + contextWindow: "131K", + }, + { + value: "kimi-k2-thinking", + label: "Kimi K2 Thinking", + provider: "MOONSHOT", + contextWindow: "262K", + }, + { + value: "moonshot-v1-128k", + label: "Moonshot V1 128K", + provider: "MOONSHOT", + contextWindow: "131K", + }, + { + value: "moonshot-v1-32k", + label: "Moonshot V1 32K", + provider: "MOONSHOT", + contextWindow: "33K", + }, + { + value: "moonshot-v1-8k", + label: "Moonshot V1 8K", + provider: "MOONSHOT", + contextWindow: "8K", + }, + + // Zhipu (GLM) + { + value: "glm-4.6", + label: "GLM 4.6", + provider: "ZHIPU", + contextWindow: "203K", + }, + { + value: "glm-4.6:exacto", + label: "GLM 4.6 Exacto", + provider: "ZHIPU", + contextWindow: "203K", + }, + + // Anyscale + { + value: "meta-llama/Meta-Llama-3-70B-Instruct", + label: "Anyscale Llama 3 70B", + provider: "ANYSCALE", + contextWindow: "8K", + }, + { + value: "meta-llama/Meta-Llama-3-8B-Instruct", + label: "Anyscale Llama 3 8B", + provider: "ANYSCALE", + contextWindow: "8K", + }, + { + value: "mistralai/Mixtral-8x7B-Instruct-v0.1", + label: "Anyscale Mixtral 8x7B", + provider: "ANYSCALE", + contextWindow: "16K", + }, + + // DeepInfra + { + value: "meta-llama/Meta-Llama-3.3-70B-Instruct", + label: "DeepInfra Llama 3.3 70B", + provider: "DEEPINFRA", + contextWindow: "131K", + }, + { + value: "meta-llama/Meta-Llama-3.1-405B-Instruct", + label: "DeepInfra Llama 3.1 405B", + provider: "DEEPINFRA", + contextWindow: "33K", + }, + { + value: "meta-llama/Meta-Llama-3.1-70B-Instruct", + label: "DeepInfra Llama 3.1 70B", + provider: "DEEPINFRA", + contextWindow: "131K", + }, + { + value: "deepseek-ai/DeepSeek-V3", + label: "DeepInfra DeepSeek V3", + provider: "DEEPINFRA", + contextWindow: "164K", + }, + { + value: "deepseek-ai/DeepSeek-R1", + label: "DeepInfra DeepSeek R1", + provider: "DEEPINFRA", + contextWindow: "164K", + }, + { + value: "Qwen/Qwen2.5-72B-Instruct", + label: "DeepInfra Qwen 2.5 72B", + provider: "DEEPINFRA", + contextWindow: "33K", + }, + { + value: "Qwen/Qwen3-235B-A22B", + label: "DeepInfra Qwen3 235B", + provider: "DEEPINFRA", + contextWindow: "131K", + }, + { + value: "google/gemini-2.5-flash", + label: "DeepInfra Gemini 2.5 Flash", + provider: "DEEPINFRA", + contextWindow: "1M", + }, + { + value: "anthropic/claude-3-7-sonnet-latest", + label: "DeepInfra Claude 3.7 Sonnet", + provider: "DEEPINFRA", + contextWindow: "200K", + }, + + // Cerebras + { + value: "llama-3.3-70b", + label: "Cerebras Llama 3.3 70B", + provider: "CEREBRAS", + contextWindow: "128K", + }, + { + value: "llama3.1-70b", + label: "Cerebras Llama 3.1 70B", + provider: "CEREBRAS", + contextWindow: "128K", + }, + { + value: "llama3.1-8b", + label: "Cerebras Llama 3.1 8B", + provider: "CEREBRAS", + contextWindow: "128K", + }, + { + value: "qwen-3-32b", + label: "Cerebras Qwen 3 32B", + provider: "CEREBRAS", + contextWindow: "128K", + }, + { + value: "gpt-oss-120b", + label: "Cerebras GPT-OSS-120B", + provider: "CEREBRAS", + contextWindow: "131K", + }, + + // SambaNova + { + value: "Meta-Llama-3.3-70B-Instruct", + label: "SambaNova Llama 3.3 70B", + provider: "SAMBANOVA", + contextWindow: "131K", + }, + { + value: "Meta-Llama-3.1-405B-Instruct", + label: "SambaNova Llama 3.1 405B", + provider: "SAMBANOVA", + contextWindow: "16K", + }, + { + value: "Meta-Llama-3.1-8B-Instruct", + label: "SambaNova Llama 3.1 8B", + provider: "SAMBANOVA", + contextWindow: "16K", + }, + { + value: "DeepSeek-R1", + label: "SambaNova DeepSeek R1", + provider: "SAMBANOVA", + contextWindow: "33K", + }, + { + value: "DeepSeek-V3-0324", + label: "SambaNova DeepSeek V3", + provider: "SAMBANOVA", + contextWindow: "33K", + }, + { + value: "Llama-4-Maverick-17B-128E-Instruct", + label: "SambaNova Llama 4 Maverick", + provider: "SAMBANOVA", + contextWindow: "131K", + }, + { + value: "Llama-4-Scout-17B-16E-Instruct", + label: "SambaNova Llama 4 Scout", + provider: "SAMBANOVA", + contextWindow: "8K", + }, + { + value: "QwQ-32B", + label: "SambaNova QwQ 32B", + provider: "SAMBANOVA", + contextWindow: "16K", + }, + { + value: "Qwen3-32B", + label: "SambaNova Qwen3 32B", + provider: "SAMBANOVA", + contextWindow: "8K", + }, + + // AI21 Labs + { + value: "jamba-1.5-large", + label: "Jamba 1.5 Large", + provider: "AI21", + contextWindow: "256K", + }, + { + value: "jamba-1.5-mini", + label: "Jamba 1.5 Mini", + provider: "AI21", + contextWindow: "256K", + }, + { + value: "jamba-large-1.6", + label: "Jamba Large 1.6", + provider: "AI21", + contextWindow: "256K", + }, + { + value: "jamba-mini-1.6", + label: "Jamba Mini 1.6", + provider: "AI21", + contextWindow: "256K", + }, + + // Cloudflare + { + value: "@cf/meta/llama-2-7b-chat-fp16", + label: "Cloudflare Llama 2 7B", + provider: "CLOUDFLARE", + contextWindow: "3K", + }, + { + value: "@cf/mistral/mistral-7b-instruct-v0.1", + label: "Cloudflare Mistral 7B", + provider: "CLOUDFLARE", + contextWindow: "8K", + }, + + // Databricks + { + value: "databricks-meta-llama-3-3-70b-instruct", + label: "Databricks Llama 3.3 70B", + provider: "DATABRICKS", + contextWindow: "128K", + }, + { + value: "databricks-meta-llama-3-1-405b-instruct", + label: "Databricks Llama 3.1 405B", + provider: "DATABRICKS", + contextWindow: "128K", + }, + { + value: "databricks-claude-3-7-sonnet", + label: "Databricks Claude 3.7 Sonnet", + provider: "DATABRICKS", + contextWindow: "200K", + }, + { + value: "databricks-llama-4-maverick", + label: "Databricks Llama 4 Maverick", + provider: "DATABRICKS", + contextWindow: "128K", + }, + + // GitHub Models + { + value: "openai/gpt-5", + label: "GitHub GPT-5", + provider: "GITHUB_MODELS", + }, + { + value: "openai/gpt-4.1", + label: "GitHub GPT-4.1", + provider: "GITHUB_MODELS", + contextWindow: "1048K", + }, + { + value: "openai/gpt-4o", + label: "GitHub GPT-4o", + provider: "GITHUB_MODELS", + contextWindow: "128K", + }, + { + value: "deepseek/DeepSeek-V3-0324", + label: "GitHub DeepSeek V3", + provider: "GITHUB_MODELS", + contextWindow: "64K", + }, + { + value: "xai/grok-3", + label: "GitHub Grok 3", + provider: "GITHUB_MODELS", + contextWindow: "131K", + }, + { + value: "openai/gpt-5-mini", + label: "GitHub GPT-5 Mini", + provider: "GITHUB_MODELS", + }, + { + value: "openai/gpt-4.1-mini", + label: "GitHub GPT-4.1 Mini", + provider: "GITHUB_MODELS", + contextWindow: "1048K", + }, + { + value: "meta/Llama-4-Scout-17B-16E-Instruct", + label: "GitHub Llama 4 Scout", + provider: "GITHUB_MODELS", + contextWindow: "512K", + }, + { + value: "openai/gpt-4.1-nano", + label: "GitHub GPT-4.1 Nano", + provider: "GITHUB_MODELS", + contextWindow: "1048K", + }, + { + value: "openai/gpt-4o-mini", + label: "GitHub GPT-4o Mini", + provider: "GITHUB_MODELS", + contextWindow: "128K", + }, + { + value: "openai/o4-mini", + label: "GitHub O4 Mini", + provider: "GITHUB_MODELS", + contextWindow: "200K", + }, + { + value: "deepseek/DeepSeek-R1", + label: "GitHub DeepSeek R1", + provider: "GITHUB_MODELS", + contextWindow: "64K", + }, + + // MiniMax + { + value: "MiniMax-M3", + label: "MiniMax M3", + provider: "MINIMAX", + contextWindow: "512K", + }, + { + value: "MiniMax-M2.7", + label: "MiniMax M2.7", + provider: "MINIMAX", + contextWindow: "204K", + }, + { + value: "MiniMax-M2.7-highspeed", + label: "MiniMax M2.7 Highspeed", + provider: "MINIMAX", + contextWindow: "204K", + }, +]; + +// Helper function to get models by provider +export function getModelsByProvider(provider: string): LLMModel[] { + return LLM_MODELS.filter((model) => model.provider === provider); +} + +// Helper function to get all providers that have models +export function getProvidersWithModels(): string[] { + return Array.from(new Set(LLM_MODELS.map((model) => model.provider))); +} diff --git a/surfsense_web/contracts/enums/llm-providers.ts b/surfsense_web/contracts/enums/llm-providers.ts new file mode 100644 index 000000000..c04a44923 --- /dev/null +++ b/surfsense_web/contracts/enums/llm-providers.ts @@ -0,0 +1,197 @@ +export interface LLMProvider { + value: string; + label: string; + example: string; + description: string; + apiBase?: string; +} + +export const LLM_PROVIDERS: LLMProvider[] = [ + { + value: "OPENAI", + label: "OpenAI", + example: "gpt-4o, gpt-4o-mini, o1, o3-mini", + description: "Industry-leading GPT models", + }, + { + value: "ANTHROPIC", + label: "Anthropic", + example: "claude-3-5-sonnet, claude-3-opus, claude-4-sonnet", + description: "Claude models with strong reasoning", + }, + { + value: "GOOGLE", + label: "Google (Gemini)", + example: "gemini-2.5-flash, gemini-2.5-pro, gemini-1.5-pro", + description: "Gemini models with multimodal capabilities", + }, + { + value: "AZURE_OPENAI", + label: "Azure OpenAI", + example: "azure/gpt-4o, azure/gpt-4o-mini", + description: "OpenAI models on Azure", + }, + { + value: "BEDROCK", + label: "AWS Bedrock", + example: "anthropic.claude-3-5-sonnet, meta.llama3-70b", + description: "Foundation models on AWS", + }, + { + value: "VERTEX_AI", + label: "Google Vertex AI", + example: "vertex_ai/claude-3-5-sonnet, vertex_ai/gemini-2.5-pro", + description: "Models on Google Cloud Vertex AI", + }, + { + value: "GROQ", + label: "Groq", + example: "groq/llama-3.3-70b-versatile, groq/mixtral-8x7b", + description: "Ultra-fast inference", + }, + { + value: "COHERE", + label: "Cohere", + example: "command-a-03-2025, command-r-plus", + description: "Enterprise NLP models", + }, + { + value: "MISTRAL", + label: "Mistral AI", + example: "mistral-large-latest, mistral-medium-latest", + description: "European open-source models", + }, + { + value: "DEEPSEEK", + label: "DeepSeek", + example: "deepseek-chat, deepseek-reasoner", + description: "High-performance reasoning models", + apiBase: "https://api.deepseek.com", + }, + { + value: "XAI", + label: "xAI (Grok)", + example: "grok-4, grok-3, grok-3-mini", + description: "Grok models from xAI", + }, + { + value: "OPENROUTER", + label: "OpenRouter", + example: "openrouter/anthropic/claude-4-opus", + description: "Unified API for multiple providers", + }, + { + value: "TOGETHER_AI", + label: "Together AI", + example: "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo", + description: "Fast open-source models", + }, + { + value: "FIREWORKS_AI", + label: "Fireworks AI", + example: "fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct", + description: "Scalable inference platform", + }, + { + value: "REPLICATE", + label: "Replicate", + example: "replicate/meta/llama-3-70b-instruct", + description: "ML model hosting platform", + }, + { + value: "PERPLEXITY", + label: "Perplexity", + example: "perplexity/sonar-pro, perplexity/sonar-reasoning", + description: "Search-augmented models", + }, + { + value: "OLLAMA", + label: "Ollama", + example: "ollama/llama3.1, ollama/mistral", + description: "Run models locally", + apiBase: "http://localhost:11434", + }, + { + value: "ALIBABA_QWEN", + label: "Alibaba Qwen", + example: "dashscope/qwen-plus, dashscope/qwen-turbo", + description: "Qwen series models", + apiBase: "https://dashscope.aliyuncs.com/compatible-mode/v1", + }, + { + value: "MOONSHOT", + label: "Moonshot (Kimi)", + example: "moonshot/kimi-latest, moonshot/kimi-k2-thinking", + description: "Kimi AI models", + apiBase: "https://api.moonshot.cn/v1", + }, + { + value: "ZHIPU", + label: "Zhipu (GLM)", + example: "glm-4.6, glm-4.6:exacto", + description: "GLM series models", + apiBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + value: "ANYSCALE", + label: "Anyscale", + example: "anyscale/meta-llama/Meta-Llama-3-70B-Instruct", + description: "Ray-based inference platform", + }, + { + value: "DEEPINFRA", + label: "DeepInfra", + example: "deepinfra/meta-llama/Meta-Llama-3.3-70B-Instruct", + description: "Serverless GPU inference", + }, + { + value: "CEREBRAS", + label: "Cerebras", + example: "cerebras/llama-3.3-70b, cerebras/qwen-3-32b", + description: "Fastest inference with Wafer-Scale Engine", + }, + { + value: "SAMBANOVA", + label: "SambaNova", + example: "sambanova/Meta-Llama-3.3-70B-Instruct", + description: "AI inference platform", + }, + { + value: "AI21", + label: "AI21 Labs", + example: "jamba-1.5-large, jamba-1.5-mini", + description: "Jamba series models", + }, + { + value: "CLOUDFLARE", + label: "Cloudflare Workers AI", + example: "cloudflare/@cf/meta/llama-2-7b-chat", + description: "AI on Cloudflare edge network", + }, + { + value: "DATABRICKS", + label: "Databricks", + example: "databricks/databricks-meta-llama-3-3-70b-instruct", + description: "Databricks Model Serving", + }, + { + value: "GITHUB_MODELS", + label: "GitHub Models", + example: "openai/gpt-5, meta/llama-3.1-405b-instruct", + description: "AI models from GitHub Marketplace", + apiBase: "https://models.github.ai/inference", + }, + { + value: "MINIMAX", + label: "MiniMax", + example: "MiniMax-M3, MiniMax-M2.7", + description: "High-performance models with up to 512K context", + apiBase: "https://api.minimax.io/v1", + }, + { + value: "CUSTOM", + label: "Custom Provider", + example: "your-custom-model", + description: "Custom OpenAI-compatible endpoint", + }, +]; diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 8b2f08cd1..494c0eaee 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,5 +1,4 @@ import { - AlarmClock, Brain, Calendar, FileEdit, @@ -25,6 +24,7 @@ import { SearchCheck, Send, Trash2, + Workflow, Wrench, } from "lucide-react"; @@ -47,7 +47,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = { scrape_webpage: ScanLine, web_search: Globe, // Automations - create_automation: AlarmClock, + create_automation: Workflow, // Memory update_memory: Brain, // Filesystem (built-in deepagent + middleware) diff --git a/surfsense_web/contracts/enums/vision-providers.ts b/surfsense_web/contracts/enums/vision-providers.ts new file mode 100644 index 000000000..477fd5c53 --- /dev/null +++ b/surfsense_web/contracts/enums/vision-providers.ts @@ -0,0 +1,168 @@ +import type { LLMModel } from "./llm-models"; + +export interface VisionProviderInfo { + value: string; + label: string; + example: string; + description: string; + apiBase?: string; +} + +export const VISION_PROVIDERS: VisionProviderInfo[] = [ + { + value: "OPENAI", + label: "OpenAI", + example: "gpt-4o, gpt-4o-mini", + description: "GPT-4o vision models", + }, + { + value: "ANTHROPIC", + label: "Anthropic", + example: "claude-sonnet-4-20250514", + description: "Claude vision models", + }, + { + value: "GOOGLE", + label: "Google AI Studio", + example: "gemini-2.5-flash, gemini-2.0-flash", + description: "Gemini vision models", + }, + { + value: "AZURE_OPENAI", + label: "Azure OpenAI", + example: "azure/gpt-4o", + description: "OpenAI vision models on Azure", + }, + { + value: "VERTEX_AI", + label: "Google Vertex AI", + example: "vertex_ai/gemini-2.5-flash", + description: "Gemini vision models on Vertex AI", + }, + { + value: "BEDROCK", + label: "AWS Bedrock", + example: "bedrock/anthropic.claude-sonnet-4-20250514-v1:0", + description: "Vision models on AWS Bedrock", + }, + { + value: "XAI", + label: "xAI", + example: "grok-2-vision", + description: "Grok vision models", + }, + { + value: "OPENROUTER", + label: "OpenRouter", + example: "openrouter/openai/gpt-4o", + description: "Vision models via OpenRouter", + }, + { + value: "OLLAMA", + label: "Ollama", + example: "llava, bakllava", + description: "Local vision models via Ollama", + apiBase: "http://localhost:11434", + }, + { + value: "GROQ", + label: "Groq", + example: "llama-4-scout-17b-16e-instruct", + description: "Vision models on Groq", + }, + { + value: "TOGETHER_AI", + label: "Together AI", + example: "meta-llama/Llama-4-Scout-17B-16E-Instruct", + description: "Vision models on Together AI", + }, + { + value: "FIREWORKS_AI", + label: "Fireworks AI", + example: "fireworks_ai/phi-3-vision-128k-instruct", + description: "Vision models on Fireworks AI", + }, + { + value: "DEEPSEEK", + label: "DeepSeek", + example: "deepseek-chat", + description: "DeepSeek vision models", + apiBase: "https://api.deepseek.com", + }, + { + value: "MISTRAL", + label: "Mistral", + example: "pixtral-large-latest", + description: "Pixtral vision models", + }, + { + value: "CUSTOM", + label: "Custom Provider", + example: "custom/my-vision-model", + description: "Custom OpenAI-compatible vision endpoint", + }, +]; + +export const VISION_MODELS: LLMModel[] = [ + { value: "gpt-4o", label: "GPT-4o", provider: "OPENAI", contextWindow: "128K" }, + { value: "gpt-4o-mini", label: "GPT-4o Mini", provider: "OPENAI", contextWindow: "128K" }, + { value: "gpt-4-turbo", label: "GPT-4 Turbo", provider: "OPENAI", contextWindow: "128K" }, + { + value: "claude-sonnet-4-20250514", + label: "Claude Sonnet 4", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-7-sonnet-20250219", + label: "Claude 3.7 Sonnet", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-5-sonnet-20241022", + label: "Claude 3.5 Sonnet", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-opus-20240229", + label: "Claude 3 Opus", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { + value: "claude-3-haiku-20240307", + label: "Claude 3 Haiku", + provider: "ANTHROPIC", + contextWindow: "200K", + }, + { value: "gemini-2.5-flash", label: "Gemini 2.5 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-2.5-pro", label: "Gemini 2.5 Pro", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-2.0-flash", label: "Gemini 2.0 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-1.5-pro", label: "Gemini 1.5 Pro", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-1.5-flash", label: "Gemini 1.5 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { + value: "pixtral-large-latest", + label: "Pixtral Large", + provider: "MISTRAL", + contextWindow: "128K", + }, + { value: "pixtral-12b-2409", label: "Pixtral 12B", provider: "MISTRAL", contextWindow: "128K" }, + { value: "grok-2-vision-1212", label: "Grok 2 Vision", provider: "XAI", contextWindow: "32K" }, + { value: "llava", label: "LLaVA", provider: "OLLAMA" }, + { value: "bakllava", label: "BakLLaVA", provider: "OLLAMA" }, + { value: "llava-llama3", label: "LLaVA Llama 3", provider: "OLLAMA" }, + { + value: "llama-4-scout-17b-16e-instruct", + label: "Llama 4 Scout 17B", + provider: "GROQ", + contextWindow: "128K", + }, + { + value: "meta-llama/Llama-4-Scout-17B-16E-Instruct", + label: "Llama 4 Scout 17B", + provider: "TOGETHER_AI", + contextWindow: "128K", + }, +]; diff --git a/surfsense_web/contracts/types/anonymous-chat.types.ts b/surfsense_web/contracts/types/anonymous-chat.types.ts index 21284267c..864810d8e 100644 --- a/surfsense_web/contracts/types/anonymous-chat.types.ts +++ b/surfsense_web/contracts/types/anonymous-chat.types.ts @@ -3,6 +3,7 @@ import { z } from "zod"; export const anonModel = z.object({ id: z.number(), name: z.string(), + description: z.string().nullable().optional(), provider: z.string(), model_name: z.string(), billing_tier: z.string().default("free"), diff --git a/surfsense_web/contracts/types/auth.types.ts b/surfsense_web/contracts/types/auth.types.ts index b630c461b..29a296c11 100644 --- a/surfsense_web/contracts/types/auth.types.ts +++ b/surfsense_web/contracts/types/auth.types.ts @@ -20,7 +20,8 @@ export const registerRequest = loginRequest.omit({ grant_type: true, username: t export const registerResponse = registerRequest.omit({ password: true }).extend({ id: z.string(), - credit_micros_balance: z.number(), + pages_limit: z.number(), + pages_used: z.number(), }); export type LoginRequest = z.infer<typeof loginRequest>; diff --git a/surfsense_web/contracts/types/automation.types.ts b/surfsense_web/contracts/types/automation.types.ts index 6331a663c..45670d245 100644 --- a/surfsense_web/contracts/types/automation.types.ts +++ b/surfsense_web/contracts/types/automation.types.ts @@ -63,9 +63,9 @@ export type Inputs = z.infer<typeof inputs>; // Captured model snapshot (server-managed). Set at create time and preserved // across edits so runs are insulated from later chat/search-space model changes. export const automationModels = z.object({ - chat_model_id: z.number().int().default(0), - image_gen_model_id: z.number().int().default(0), - vision_model_id: z.number().int().default(0), + agent_llm_id: z.number().int().default(0), + image_generation_config_id: z.number().int().default(0), + vision_llm_config_id: z.number().int().default(0), }); export type AutomationModels = z.infer<typeof automationModels>; diff --git a/surfsense_web/contracts/types/inbox.types.ts b/surfsense_web/contracts/types/inbox.types.ts index 94e533809..b4cf01710 100644 --- a/surfsense_web/contracts/types/inbox.types.ts +++ b/surfsense_web/contracts/types/inbox.types.ts @@ -11,7 +11,7 @@ export const inboxItemTypeEnum = z.enum([ "document_processing", "new_mention", "comment_reply", - "insufficient_credits", + "page_limit_exceeded", ]); /** @@ -116,17 +116,15 @@ export const commentReplyMetadata = z.object({ }); /** - * Insufficient credits metadata schema. - * - * ``balance_micros`` / ``required_micros`` are integer micro-USD - * (1_000_000 == $1.00); the UI divides by 1M when displaying. + * Page limit exceeded metadata schema */ -export const insufficientCreditsMetadata = baseInboxItemMetadata.extend({ +export const pageLimitExceededMetadata = baseInboxItemMetadata.extend({ document_name: z.string(), document_type: z.string(), - balance_micros: z.number(), - required_micros: z.number(), - error_type: z.literal("insufficient_credits"), + pages_used: z.number(), + pages_limit: z.number(), + pages_to_add: z.number(), + error_type: z.literal("page_limit_exceeded"), // Navigation target for frontend action_url: z.string(), action_label: z.string(), @@ -142,7 +140,7 @@ export const inboxItemMetadata = z.union([ documentProcessingMetadata, newMentionMetadata, commentReplyMetadata, - insufficientCreditsMetadata, + pageLimitExceededMetadata, baseInboxItemMetadata, ]); @@ -190,9 +188,9 @@ export const commentReplyInboxItem = inboxItem.extend({ metadata: commentReplyMetadata, }); -export const insufficientCreditsInboxItem = inboxItem.extend({ - type: z.literal("insufficient_credits"), - metadata: insufficientCreditsMetadata, +export const pageLimitExceededInboxItem = inboxItem.extend({ + type: z.literal("page_limit_exceeded"), + metadata: pageLimitExceededMetadata, }); // ============================================================================= @@ -343,12 +341,12 @@ export function isCommentReplyMetadata(metadata: unknown): metadata is CommentRe } /** - * Type guard for InsufficientCreditsMetadata + * Type guard for PageLimitExceededMetadata */ -export function isInsufficientCreditsMetadata( +export function isPageLimitExceededMetadata( metadata: unknown -): metadata is InsufficientCreditsMetadata { - return insufficientCreditsMetadata.safeParse(metadata).success; +): metadata is PageLimitExceededMetadata { + return pageLimitExceededMetadata.safeParse(metadata).success; } /** @@ -363,7 +361,7 @@ export function parseInboxItemMetadata( | DocumentProcessingMetadata | NewMentionMetadata | CommentReplyMetadata - | InsufficientCreditsMetadata + | PageLimitExceededMetadata | null { switch (type) { case "connector_indexing": { @@ -386,8 +384,8 @@ export function parseInboxItemMetadata( const result = commentReplyMetadata.safeParse(metadata); return result.success ? result.data : null; } - case "insufficient_credits": { - const result = insufficientCreditsMetadata.safeParse(metadata); + case "page_limit_exceeded": { + const result = pageLimitExceededMetadata.safeParse(metadata); return result.success ? result.data : null; } default: @@ -408,7 +406,7 @@ export type ConnectorDeletionMetadata = z.infer<typeof connectorDeletionMetadata export type DocumentProcessingMetadata = z.infer<typeof documentProcessingMetadata>; export type NewMentionMetadata = z.infer<typeof newMentionMetadata>; export type CommentReplyMetadata = z.infer<typeof commentReplyMetadata>; -export type InsufficientCreditsMetadata = z.infer<typeof insufficientCreditsMetadata>; +export type PageLimitExceededMetadata = z.infer<typeof pageLimitExceededMetadata>; export type InboxItemMetadata = z.infer<typeof inboxItemMetadata>; export type InboxItem = z.infer<typeof inboxItem>; export type ConnectorIndexingInboxItem = z.infer<typeof connectorIndexingInboxItem>; @@ -416,7 +414,7 @@ export type ConnectorDeletionInboxItem = z.infer<typeof connectorDeletionInboxIt export type DocumentProcessingInboxItem = z.infer<typeof documentProcessingInboxItem>; export type NewMentionInboxItem = z.infer<typeof newMentionInboxItem>; export type CommentReplyInboxItem = z.infer<typeof commentReplyInboxItem>; -export type InsufficientCreditsInboxItem = z.infer<typeof insufficientCreditsInboxItem>; +export type PageLimitExceededInboxItem = z.infer<typeof pageLimitExceededInboxItem>; // API Request/Response types export type GetNotificationsRequest = z.infer<typeof getNotificationsRequest>; diff --git a/surfsense_web/contracts/types/incentive-tasks.types.ts b/surfsense_web/contracts/types/incentive-tasks.types.ts index abe91d905..c45121c29 100644 --- a/surfsense_web/contracts/types/incentive-tasks.types.ts +++ b/surfsense_web/contracts/types/incentive-tasks.types.ts @@ -12,8 +12,7 @@ export const incentiveTaskInfo = z.object({ task_type: incentiveTaskTypeEnum, title: z.string(), description: z.string(), - // Reward in micro-USD (1_000_000 == $1.00) credited to the wallet. - credit_micros_reward: z.number(), + pages_reward: z.number(), action_url: z.string(), completed: z.boolean(), completed_at: z.string().nullable(), @@ -24,7 +23,7 @@ export const incentiveTaskInfo = z.object({ */ export const getIncentiveTasksResponse = z.object({ tasks: z.array(incentiveTaskInfo), - total_credit_micros_earned: z.number(), + total_pages_earned: z.number(), }); /** @@ -33,8 +32,8 @@ export const getIncentiveTasksResponse = z.object({ export const completeTaskSuccessResponse = z.object({ success: z.literal(true), message: z.string(), - credit_micros_awarded: z.number(), - new_balance_micros: z.number(), + pages_awarded: z.number(), + new_pages_limit: z.number(), }); /** diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts deleted file mode 100644 index 0f0c7591e..000000000 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { z } from "zod"; - -export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); -export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); - -export const modelRead = z.object({ - id: z.number(), - connection_id: z.number(), - model_id: z.string(), - display_name: z.string().nullable().optional(), - source: z.union([modelSourceEnum, z.string()]), - supports_chat: z.boolean().nullable().optional(), - max_input_tokens: z.number().nullable().optional(), - supports_image_input: z.boolean().nullable().optional(), - supports_tools: z.boolean().nullable().optional(), - supports_image_generation: z.boolean().nullable().optional(), - capabilities_override: z.record(z.string(), z.any()).default({}), - enabled: z.boolean(), - billing_tier: z.string().nullable().optional(), - catalog: z.record(z.string(), z.any()).default({}), - created_at: z.string().nullable().optional(), -}); - -export const connectionRead = z.object({ - id: z.number(), - provider: z.string(), - base_url: z.string().nullable().optional(), - api_key: z.string().nullable().optional(), - extra: z.record(z.string(), z.any()).default({}), - scope: z.union([connectionScopeEnum, z.string()]), - search_space_id: z.number().nullable().optional(), - user_id: z.string().nullable().optional(), - enabled: z.boolean(), - has_api_key: z.boolean(), - models: z.array(modelRead).default([]), - created_at: z.string().nullable().optional(), -}); - -export const modelSelection = z.object({ - model_id: z.string().min(1), - display_name: z.string().nullable().optional(), - source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"), - supports_chat: z.boolean().nullable().optional(), - max_input_tokens: z.number().nullable().optional(), - supports_image_input: z.boolean().nullable().optional(), - supports_tools: z.boolean().nullable().optional(), - supports_image_generation: z.boolean().nullable().optional(), - enabled: z.boolean().default(false), - metadata: z.record(z.string(), z.any()).default({}), -}); - -export const modelPreviewRead = modelSelection; - -export const connectionCreateRequest = z.object({ - provider: z.string().min(1), - base_url: z.string().nullable().optional(), - api_key: z.string().nullable().optional(), - extra: z.record(z.string(), z.any()).default({}), - scope: connectionScopeEnum.default("SEARCH_SPACE"), - search_space_id: z.number().nullable().optional(), - enabled: z.boolean().default(true), - models: z.array(modelSelection).default([]), -}); - -export const modelTestPreviewRequest = connectionCreateRequest.extend({ - model_id: z.string().min(1), -}); - -export const connectionUpdateRequest = z.object({ - provider: z.string().nullable().optional(), - base_url: z.string().nullable().optional(), - api_key: z.string().nullable().optional(), - extra: z.record(z.string(), z.any()).optional(), - enabled: z.boolean().optional(), -}); - -export const modelCreateRequest = z.object({ - model_id: z.string().min(1), - display_name: z.string().nullable().optional(), -}); - -export const modelUpdateRequest = z.object({ - display_name: z.string().nullable().optional(), - enabled: z.boolean().optional(), - supports_chat: z.boolean().nullable().optional(), - max_input_tokens: z.number().nullable().optional(), - supports_image_input: z.boolean().nullable().optional(), - supports_tools: z.boolean().nullable().optional(), - supports_image_generation: z.boolean().nullable().optional(), - capabilities_override: z.record(z.string(), z.any()).optional(), -}); - -export const modelsBulkUpdateRequest = z.object({ - model_ids: z.array(z.number()).min(1).max(1000), - enabled: z.boolean(), -}); - -export const verifyConnectionResponse = z.object({ - status: z.string(), - ok: z.boolean(), - message: z.string().default(""), -}); - -export const modelRoles = z.object({ - chat_model_id: z.number().nullable().optional(), - vision_model_id: z.number().nullable().optional(), - image_gen_model_id: z.number().nullable().optional(), -}); - -export const globalLlmConfigStatus = z.object({ - exists: z.boolean(), -}); - -export const modelProviderRead = z.object({ - provider: z.string(), - transport: z.string(), - discovery: z.string(), - default_base_url: z.string().nullable().optional(), - base_url_required: z.boolean(), - auth_style: z.string(), - local_only: z.boolean().default(false), -}); - -export const modelProviderListResponse = z.array(modelProviderRead); - -export const connectionListResponse = z.array(connectionRead); -export const modelListResponse = z.array(modelRead); -export const modelPreviewListResponse = z.array(modelPreviewRead); - -export type ConnectionScope = z.infer<typeof connectionScopeEnum>; -export type ModelRead = z.infer<typeof modelRead>; -export type ModelPreviewRead = z.infer<typeof modelPreviewRead>; -export type ModelSelection = z.infer<typeof modelSelection>; -export type ConnectionRead = z.infer<typeof connectionRead>; -export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>; -export type ModelTestPreviewRequest = z.infer<typeof modelTestPreviewRequest>; -export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>; -export type ModelCreateRequest = z.infer<typeof modelCreateRequest>; -export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>; -export type ModelsBulkUpdateRequest = z.infer<typeof modelsBulkUpdateRequest>; -export type ModelRoles = z.infer<typeof modelRoles>; -export type GlobalLlmConfigStatus = z.infer<typeof globalLlmConfigStatus>; -export type VerifyConnectionResponse = z.infer<typeof verifyConnectionResponse>; -export type ModelProviderRead = z.infer<typeof modelProviderRead>; diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts new file mode 100644 index 000000000..2fa7a37be --- /dev/null +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -0,0 +1,476 @@ +import { z } from "zod"; + +/** + * LiteLLM Provider enum - all supported LLM providers + */ +export const liteLLMProviderEnum = z.enum([ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "COHERE", + "MISTRAL", + "DEEPSEEK", + "XAI", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "ANYSCALE", + "DEEPINFRA", + "CEREBRAS", + "SAMBANOVA", + "AI21", + "CLOUDFLARE", + "DATABRICKS", + "COMETAPI", + "HUGGINGFACE", + "GITHUB_MODELS", + "MINIMAX", + "CUSTOM", +]); + +export type LiteLLMProvider = z.infer<typeof liteLLMProviderEnum>; + +/** + * NewLLMConfig - combines model settings with prompt configuration + */ +export const newLLMConfig = z.object({ + id: z.number(), + name: z.string().max(100), + description: z.string().max(500).nullable().optional(), + + // Model Configuration + provider: liteLLMProviderEnum, + custom_provider: z.string().max(100).nullable().optional(), + model_name: z.string().max(100), + api_key: z.string(), + api_base: z.string().max(500).nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + + // Prompt Configuration + system_instructions: z.string().default(""), + use_default_system_instructions: z.boolean().default(true), + citations_enabled: z.boolean().default(true), + + // Metadata + created_at: z.string(), + search_space_id: z.number(), + user_id: z.string(), + + // Capability flag — derived server-side at the route boundary from + // LiteLLM's authoritative model map. There is no DB column. Default + // `true` is the conservative-allow stance for unknown / unmapped + // BYOK rows; the streaming-task safety net is the only place a + // `false` actually blocks a request. + supports_image_input: z.boolean().default(true), +}); + +/** + * Public version without api_key (for list views) + */ +export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true }); + +/** + * Create NewLLMConfig + * + * `supports_image_input` is omitted because it is derived server-side + * from LiteLLM's model map at read time — there is no DB column to + * persist a client-supplied value into. + */ +export const createNewLLMConfigRequest = newLLMConfig.omit({ + id: true, + created_at: true, + user_id: true, + supports_image_input: true, +}); + +export const createNewLLMConfigResponse = newLLMConfig; + +/** + * Get NewLLMConfigs list + */ +export const getNewLLMConfigsRequest = z.object({ + search_space_id: z.number(), + skip: z.number().optional(), + limit: z.number().optional(), +}); + +export const getNewLLMConfigsResponse = z.array(newLLMConfig); + +/** + * Get single NewLLMConfig + */ +export const getNewLLMConfigRequest = z.object({ + id: z.number(), +}); + +export const getNewLLMConfigResponse = newLLMConfig; + +/** + * Update NewLLMConfig + */ +export const updateNewLLMConfigRequest = z.object({ + id: z.number(), + data: newLLMConfig + .omit({ + id: true, + created_at: true, + search_space_id: true, + user_id: true, + // Derived server-side; not part of the writable surface. + supports_image_input: true, + }) + .partial(), +}); + +export const updateNewLLMConfigResponse = newLLMConfig; + +/** + * Delete NewLLMConfig + */ +export const deleteNewLLMConfigRequest = z.object({ + id: z.number(), +}); + +export const deleteNewLLMConfigResponse = z.object({ + message: z.string(), + id: z.number(), +}); + +/** + * Get default system instructions + */ +export const getDefaultSystemInstructionsResponse = z.object({ + default_system_instructions: z.string(), +}); + +/** + * Global NewLLMConfig - from YAML, has negative IDs + * ID 0 is reserved for "Auto" mode which uses LiteLLM Router for load balancing + */ +export const globalNewLLMConfig = z.object({ + id: z.number(), // 0 for Auto mode, negative IDs for global configs + name: z.string(), + description: z.string().nullable().optional(), + + // Model Configuration (no api_key) + provider: z.string(), // String because YAML doesn't enforce enum, "AUTO" for Auto mode + custom_provider: z.string().nullable().optional(), + model_name: z.string(), + api_base: z.string().nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + + // Prompt Configuration + system_instructions: z.string().default(""), + use_default_system_instructions: z.boolean().default(true), + citations_enabled: z.boolean().default(true), + + is_global: z.literal(true), + is_auto_mode: z.boolean().optional().default(false), // True only for Auto mode (ID 0) + + // Token quota and billing policy + billing_tier: z.string().default("free"), + is_premium: z.boolean().default(false), + anonymous_enabled: z.boolean().default(false), + seo_enabled: z.boolean().default(false), + seo_slug: z.string().nullable().optional(), + seo_title: z.string().nullable().optional(), + seo_description: z.string().nullable().optional(), + quota_reserve_tokens: z.number().nullable().optional(), + // Capability flag — true when the model can accept image inputs. + // Resolved server-side (OpenRouter dynamic configs use the OR + // `architecture.input_modalities` field; YAML / BYOK use LiteLLM's + // authoritative `supports_vision` map). The chat selector renders + // an amber "No image" hint when this is false and there are + // pending image attachments, but does not block selection — the + // backend safety net only rejects when LiteLLM *explicitly* marks + // the model as text-only, so unknown / new models still flow + // through. Default `true` matches that conservative-allow stance. + supports_image_input: z.boolean().default(true), +}); + +export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); + +// ============================================================================= +// Image Generation Config (separate table from NewLLMConfig) +// ============================================================================= + +/** + * ImageGenProvider enum - only providers that support image generation + * See: https://docs.litellm.ai/docs/image_generation#supported-providers + */ +export const imageGenProviderEnum = z.enum([ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", +]); + +export type ImageGenProvider = z.infer<typeof imageGenProviderEnum>; + +/** + * ImageGenerationConfig - user-created image gen model configs + * Separate from NewLLMConfig: no system_instructions, no citations_enabled. + */ +export const imageGenerationConfig = z.object({ + id: z.number(), + name: z.string().max(100), + description: z.string().max(500).nullable().optional(), + provider: imageGenProviderEnum, + custom_provider: z.string().max(100).nullable().optional(), + model_name: z.string().max(100), + api_key: z.string(), + api_base: z.string().max(500).nullable().optional(), + api_version: z.string().max(50).nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + created_at: z.string(), + search_space_id: z.number(), + user_id: z.string(), +}); + +export const createImageGenConfigRequest = imageGenerationConfig.omit({ + id: true, + created_at: true, + user_id: true, +}); + +export const createImageGenConfigResponse = imageGenerationConfig; + +export const getImageGenConfigsResponse = z.array(imageGenerationConfig); + +export const updateImageGenConfigRequest = z.object({ + id: z.number(), + data: imageGenerationConfig + .omit({ id: true, created_at: true, search_space_id: true, user_id: true }) + .partial(), +}); + +export const updateImageGenConfigResponse = imageGenerationConfig; + +export const deleteImageGenConfigResponse = z.object({ + message: z.string(), + id: z.number(), +}); + +/** + * Global Image Generation Config - from YAML, has negative IDs + * ID 0 is reserved for "Auto" mode (LiteLLM Router load balancing) + */ +export const globalImageGenConfig = z.object({ + id: z.number(), + name: z.string(), + description: z.string().nullable().optional(), + provider: z.string(), + custom_provider: z.string().nullable().optional(), + model_name: z.string(), + api_base: z.string().nullable().optional(), + api_version: z.string().nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + is_global: z.literal(true), + is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for image-gen too. + is_premium: z.boolean().default(false), + quota_reserve_micros: z.number().nullable().optional(), +}); + +export const getGlobalImageGenConfigsResponse = z.array(globalImageGenConfig); + +// ============================================================================= +// Vision LLM Config (separate table for vision-capable models) +// ============================================================================= + +export const visionProviderEnum = z.enum([ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", +]); + +export type VisionProvider = z.infer<typeof visionProviderEnum>; + +export const visionLLMConfig = z.object({ + id: z.number(), + name: z.string().max(100), + description: z.string().max(500).nullable().optional(), + provider: visionProviderEnum, + custom_provider: z.string().max(100).nullable().optional(), + model_name: z.string().max(100), + api_key: z.string(), + api_base: z.string().max(500).nullable().optional(), + api_version: z.string().max(50).nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + created_at: z.string(), + search_space_id: z.number(), + user_id: z.string(), +}); + +export const createVisionLLMConfigRequest = visionLLMConfig.omit({ + id: true, + created_at: true, + user_id: true, +}); + +export const createVisionLLMConfigResponse = visionLLMConfig; + +export const getVisionLLMConfigsResponse = z.array(visionLLMConfig); + +export const updateVisionLLMConfigRequest = z.object({ + id: z.number(), + data: visionLLMConfig + .omit({ id: true, created_at: true, search_space_id: true, user_id: true }) + .partial(), +}); + +export const updateVisionLLMConfigResponse = visionLLMConfig; + +export const deleteVisionLLMConfigResponse = z.object({ + message: z.string(), + id: z.number(), +}); + +export const globalVisionLLMConfig = z.object({ + id: z.number(), + name: z.string(), + description: z.string().nullable().optional(), + provider: z.string(), + custom_provider: z.string().nullable().optional(), + model_name: z.string(), + api_base: z.string().nullable().optional(), + api_version: z.string().nullable().optional(), + litellm_params: z.record(z.string(), z.any()).nullable().optional(), + is_global: z.literal(true), + is_auto_mode: z.boolean().optional().default(false), + billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for vision too. + is_premium: z.boolean().default(false), + quota_reserve_tokens: z.number().nullable().optional(), + input_cost_per_token: z.number().nullable().optional(), + output_cost_per_token: z.number().nullable().optional(), +}); + +export const getGlobalVisionLLMConfigsResponse = z.array(globalVisionLLMConfig); + +// ============================================================================= +// LLM Preferences (Role Assignments) +// ============================================================================= + +export const llmPreferences = z.object({ + agent_llm_id: z.union([z.number(), z.null()]).optional(), + image_generation_config_id: z.union([z.number(), z.null()]).optional(), + vision_llm_config_id: z.union([z.number(), z.null()]).optional(), + agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), + image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), + vision_llm_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(), +}); + +/** + * Get LLM preferences + */ +export const getLLMPreferencesRequest = z.object({ + search_space_id: z.number(), +}); + +export const getLLMPreferencesResponse = llmPreferences; + +/** + * Update LLM preferences + */ +export const updateLLMPreferencesRequest = z.object({ + search_space_id: z.number(), + data: llmPreferences.pick({ + agent_llm_id: true, + image_generation_config_id: true, + vision_llm_config_id: true, + }), +}); + +export const updateLLMPreferencesResponse = llmPreferences; + +// ============================================================================= +// Model List (dynamic catalogue from OpenRouter API) +// ============================================================================= + +export const modelListItem = z.object({ + value: z.string(), + label: z.string(), + provider: z.string(), + context_window: z.string().nullable().optional(), +}); + +export const getModelListResponse = z.array(modelListItem); + +export type ModelListItem = z.infer<typeof modelListItem>; +export type GetModelListResponse = z.infer<typeof getModelListResponse>; + +// ============================================================================= +// Type Exports +// ============================================================================= + +export type NewLLMConfig = z.infer<typeof newLLMConfig>; +export type NewLLMConfigPublic = z.infer<typeof newLLMConfigPublic>; +export type CreateNewLLMConfigRequest = z.infer<typeof createNewLLMConfigRequest>; +export type CreateNewLLMConfigResponse = z.infer<typeof createNewLLMConfigResponse>; +export type GetNewLLMConfigsRequest = z.infer<typeof getNewLLMConfigsRequest>; +export type GetNewLLMConfigsResponse = z.infer<typeof getNewLLMConfigsResponse>; +export type GetNewLLMConfigRequest = z.infer<typeof getNewLLMConfigRequest>; +export type GetNewLLMConfigResponse = z.infer<typeof getNewLLMConfigResponse>; +export type UpdateNewLLMConfigRequest = z.infer<typeof updateNewLLMConfigRequest>; +export type UpdateNewLLMConfigResponse = z.infer<typeof updateNewLLMConfigResponse>; +export type DeleteNewLLMConfigRequest = z.infer<typeof deleteNewLLMConfigRequest>; +export type DeleteNewLLMConfigResponse = z.infer<typeof deleteNewLLMConfigResponse>; +export type GetDefaultSystemInstructionsResponse = z.infer< + typeof getDefaultSystemInstructionsResponse +>; +export type GlobalNewLLMConfig = z.infer<typeof globalNewLLMConfig>; +export type GetGlobalNewLLMConfigsResponse = z.infer<typeof getGlobalNewLLMConfigsResponse>; +export type ImageGenerationConfig = z.infer<typeof imageGenerationConfig>; +export type CreateImageGenConfigRequest = z.infer<typeof createImageGenConfigRequest>; +export type CreateImageGenConfigResponse = z.infer<typeof createImageGenConfigResponse>; +export type GetImageGenConfigsResponse = z.infer<typeof getImageGenConfigsResponse>; +export type UpdateImageGenConfigRequest = z.infer<typeof updateImageGenConfigRequest>; +export type UpdateImageGenConfigResponse = z.infer<typeof updateImageGenConfigResponse>; +export type DeleteImageGenConfigResponse = z.infer<typeof deleteImageGenConfigResponse>; +export type GlobalImageGenConfig = z.infer<typeof globalImageGenConfig>; +export type GetGlobalImageGenConfigsResponse = z.infer<typeof getGlobalImageGenConfigsResponse>; +export type VisionLLMConfig = z.infer<typeof visionLLMConfig>; +export type CreateVisionLLMConfigRequest = z.infer<typeof createVisionLLMConfigRequest>; +export type CreateVisionLLMConfigResponse = z.infer<typeof createVisionLLMConfigResponse>; +export type GetVisionLLMConfigsResponse = z.infer<typeof getVisionLLMConfigsResponse>; +export type UpdateVisionLLMConfigRequest = z.infer<typeof updateVisionLLMConfigRequest>; +export type UpdateVisionLLMConfigResponse = z.infer<typeof updateVisionLLMConfigResponse>; +export type DeleteVisionLLMConfigResponse = z.infer<typeof deleteVisionLLMConfigResponse>; +export type GlobalVisionLLMConfig = z.infer<typeof globalVisionLLMConfig>; +export type GetGlobalVisionLLMConfigsResponse = z.infer<typeof getGlobalVisionLLMConfigsResponse>; +export type LLMPreferences = z.infer<typeof llmPreferences>; +export type GetLLMPreferencesRequest = z.infer<typeof getLLMPreferencesRequest>; +export type GetLLMPreferencesResponse = z.infer<typeof getLLMPreferencesResponse>; +export type UpdateLLMPreferencesRequest = z.infer<typeof updateLLMPreferencesRequest>; +export type UpdateLLMPreferencesResponse = z.infer<typeof updateLLMPreferencesResponse>; diff --git a/surfsense_web/contracts/types/podcast.types.ts b/surfsense_web/contracts/types/podcast.types.ts deleted file mode 100644 index 31311c469..000000000 --- a/surfsense_web/contracts/types/podcast.types.ts +++ /dev/null @@ -1,157 +0,0 @@ -import { z } from "zod"; - -// ============================================================================= -// Lifecycle — mirror app/podcasts/persistence/enums/podcast_status.py -// ============================================================================= - -export const podcastStatus = z.enum([ - "pending", - "awaiting_brief", - "drafting", - "awaiting_review", - "rendering", - "ready", - "failed", - "cancelled", -]); -export type PodcastStatus = z.infer<typeof podcastStatus>; - -/** - * States waiting on user input before the lifecycle can proceed. The brief is - * the only approval gate; `awaiting_review` survives in the enum for legacy - * rows but is never entered anymore. - */ -export const GATE_STATUSES: ReadonlySet<PodcastStatus> = new Set(["awaiting_brief"]); - -/** - * States from which no further transition is possible. A `ready` episode is - * not terminal: it can be sent back to drafting for regeneration. - */ -export const TERMINAL_STATUSES: ReadonlySet<PodcastStatus> = new Set(["failed", "cancelled"]); - -// ============================================================================= -// Brief (spec) — mirror app/podcasts/schemas/spec.py -// ============================================================================= - -export const speakerRole = z.enum(["host", "cohost", "guest", "expert", "narrator"]); -export type SpeakerRole = z.infer<typeof speakerRole>; - -export const podcastStyle = z.enum([ - "conversational", - "interview", - "debate", - "monologue", - "narrative", -]); -export type PodcastStyle = z.infer<typeof podcastStyle>; - -export const MAX_SPEAKERS = 6; - -export const MAX_DURATION_SECONDS = 24 * 60 * 60; -export const MIN_DURATION_SECONDS = 15; -export const DEFAULT_MIN_SECONDS = 20; -export const DEFAULT_MAX_SECONDS = 30; - -export const speakerSpec = z.object({ - slot: z.number().int().min(0), - name: z.string().min(1).max(120), - role: speakerRole, - voice_id: z.string().min(1), -}); -export type SpeakerSpec = z.infer<typeof speakerSpec>; - -export const durationTarget = z.preprocess( - (raw) => { - if (raw && typeof raw === "object" && "min_minutes" in raw && !("min_seconds" in raw)) { - const legacy = raw as { min_minutes: number; max_minutes: number }; - return { - min_seconds: legacy.min_minutes * 60, - max_seconds: legacy.max_minutes * 60, - }; - } - return raw; - }, - z - .object({ - min_seconds: z.number().int().min(MIN_DURATION_SECONDS).max(MAX_DURATION_SECONDS), - max_seconds: z.number().int().min(MIN_DURATION_SECONDS).max(MAX_DURATION_SECONDS), - }) - .refine((duration) => duration.max_seconds >= duration.min_seconds, { - message: "Max length must be at least min length", - path: ["max_seconds"], - }) -); -export type DurationTarget = z.infer<typeof durationTarget>; - -export const podcastSpec = z - .object({ - language: z.string().min(2), - style: podcastStyle, - speakers: z.array(speakerSpec).min(1).max(MAX_SPEAKERS), - duration: durationTarget, - focus: z.string().max(2000).nullable().optional(), - }) - // Mirrors the backend invariant: one voice is what "monologue" means. - .refine((spec) => spec.style !== "monologue" || spec.speakers.length === 1, { - message: "A monologue has exactly one speaker", - path: ["speakers"], - }); -export type PodcastSpec = z.infer<typeof podcastSpec>; - -// ============================================================================= -// Transcript — mirror app/podcasts/schemas/transcript.py -// ============================================================================= - -export const transcriptTurn = z.object({ - speaker: z.number().int().min(0), - text: z.string().min(1), -}); -export type TranscriptTurn = z.infer<typeof transcriptTurn>; - -export const transcript = z.object({ - turns: z.array(transcriptTurn).min(1), -}); -export type Transcript = z.infer<typeof transcript>; - -// ============================================================================= -// API shapes — mirror app/podcasts/api/schemas.py -// ============================================================================= - -export const voiceOption = z.object({ - voice_id: z.string(), - display_name: z.string(), - language: z.string(), - gender: z.string(), -}); -export type VoiceOption = z.infer<typeof voiceOption>; - -// The languages the backend offers for the active TTS provider. When -// `allows_custom` is true the list is a starting point and any BCP-47 tag -// may be entered. -export const languageOptions = z.object({ - languages: z.array(z.string()), - allows_custom: z.boolean(), -}); -export type LanguageOptions = z.infer<typeof languageOptions>; - -export const updateSpecRequest = z.object({ - spec: podcastSpec, - expected_version: z.number().int().min(1), -}); -export type UpdateSpecRequest = z.infer<typeof updateSpecRequest>; - -export const podcastDetail = z.object({ - id: z.number(), - title: z.string(), - status: podcastStatus, - spec_version: z.number(), - spec: podcastSpec.nullable(), - transcript: transcript.nullable(), - has_audio: z.boolean(), - duration_seconds: z.number().nullable(), - error: z.string().nullable(), - created_at: z.string(), - search_space_id: z.number(), - thread_id: z.number().nullable(), -}); -export type PodcastDetail = z.infer<typeof podcastDetail>; diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts index c548a3dd0..35ec0cb17 100644 --- a/surfsense_web/contracts/types/stripe.types.ts +++ b/surfsense_web/contracts/types/stripe.types.ts @@ -1,49 +1,20 @@ import { z } from "zod"; -export const purchaseStatusEnum = z.enum(["pending", "completed", "failed"]); +export const pagePurchaseStatusEnum = z.enum(["pending", "completed", "failed"]); -// --------------------------------------------------------------------------- -// Credit purchases ($1 packs that top up credit_micros_balance) -// --------------------------------------------------------------------------- - -export const createCreditCheckoutSessionRequest = z.object({ - quantity: z.number().int().min(1).max(10_000), +export const createCheckoutSessionRequest = z.object({ + quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), }); -export const createCreditCheckoutSessionResponse = z.object({ +export const createCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); -// Credit balance availability + records. Unit is integer micro-USD -// (1_000_000 == $1.00); the FE divides by 1M when displaying. -export const creditStripeStatusResponse = z.object({ - credit_buying_enabled: z.boolean(), - credit_micros_balance: z.number().default(0), +export const stripeStatusResponse = z.object({ + page_buying_enabled: z.boolean(), }); -export const creditPurchase = z.object({ - id: z.uuid(), - stripe_checkout_session_id: z.string(), - stripe_payment_intent_id: z.string().nullable(), - quantity: z.number(), - credit_micros_granted: z.number(), - amount_total: z.number().nullable(), - currency: z.string().nullable(), - source: z.string().default("checkout"), - status: purchaseStatusEnum, - completed_at: z.string().nullable(), - created_at: z.string(), -}); - -export const getCreditPurchasesResponse = z.object({ - purchases: z.array(creditPurchase), -}); - -// --------------------------------------------------------------------------- -// Legacy page purchases (read-only history; page buying is removed) -// --------------------------------------------------------------------------- - export const pagePurchase = z.object({ id: z.uuid(), stripe_checkout_session_id: z.string(), @@ -52,7 +23,7 @@ export const pagePurchase = z.object({ pages_granted: z.number(), amount_total: z.number().nullable(), currency: z.string().nullable(), - status: purchaseStatusEnum, + status: pagePurchaseStatusEnum, completed_at: z.string().nullable(), created_at: z.string(), }); @@ -61,59 +32,70 @@ export const getPagePurchasesResponse = z.object({ purchases: z.array(pagePurchase), }); -// Response from /stripe/finalize-checkout (credit purchases only). -export const finalizeCheckoutResponse = z.object({ - status: purchaseStatusEnum, - credit_micros_balance: z.number().default(0), - credit_micros_granted: z.number().nullable().optional(), -}); - -// --------------------------------------------------------------------------- -// Auto-reload (off-session top-up when the balance drops below a threshold) -// All *_micros fields are integer micro-USD (1_000_000 == $1.00). -// --------------------------------------------------------------------------- - -export const autoReloadSettingsResponse = z.object({ - feature_enabled: z.boolean(), - enabled: z.boolean().default(false), - threshold_micros: z.number().nullable(), - amount_micros: z.number().nullable(), - min_amount_micros: z.number(), - has_payment_method: z.boolean().default(false), - failed_at: z.string().nullable(), -}); - -export const updateAutoReloadSettingsRequest = z.object({ - enabled: z.boolean(), - threshold_micros: z.number().int().min(0).nullable().optional(), - amount_micros: z.number().int().min(0).nullable().optional(), -}); - -export const createAutoReloadSetupSessionRequest = z.object({ +// Premium credit purchases +export const createTokenCheckoutSessionRequest = z.object({ + quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), }); -export const createAutoReloadSetupSessionResponse = z.object({ +export const createTokenCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); -export type AutoReloadSettingsResponse = z.infer<typeof autoReloadSettingsResponse>; -export type UpdateAutoReloadSettingsRequest = z.infer<typeof updateAutoReloadSettingsRequest>; -export type CreateAutoReloadSetupSessionRequest = z.infer< - typeof createAutoReloadSetupSessionRequest ->; -export type CreateAutoReloadSetupSessionResponse = z.infer< - typeof createAutoReloadSetupSessionResponse ->; +// Premium credit balance + purchase records. +// +// The unit is integer micro-USD (1_000_000 == $1.00). The schema names +// kept the ``Token`` prefix for API back-compat with pinned clients; +// the field names below are authoritative. +export const tokenStripeStatusResponse = z.object({ + token_buying_enabled: z.boolean(), + premium_credit_micros_used: z.number().default(0), + premium_credit_micros_limit: z.number().default(0), + premium_credit_micros_remaining: z.number().default(0), +}); -export type PurchaseStatus = z.infer<typeof purchaseStatusEnum>; -export type CreateCreditCheckoutSessionRequest = z.infer<typeof createCreditCheckoutSessionRequest>; -export type CreateCreditCheckoutSessionResponse = z.infer< - typeof createCreditCheckoutSessionResponse ->; -export type CreditStripeStatusResponse = z.infer<typeof creditStripeStatusResponse>; -export type CreditPurchase = z.infer<typeof creditPurchase>; -export type GetCreditPurchasesResponse = z.infer<typeof getCreditPurchasesResponse>; +export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum; + +export const tokenPurchase = z.object({ + id: z.uuid(), + stripe_checkout_session_id: z.string(), + stripe_payment_intent_id: z.string().nullable(), + quantity: z.number(), + credit_micros_granted: z.number(), + amount_total: z.number().nullable(), + currency: z.string().nullable(), + status: tokenPurchaseStatusEnum, + completed_at: z.string().nullable(), + created_at: z.string(), +}); + +export const getTokenPurchasesResponse = z.object({ + purchases: z.array(tokenPurchase), +}); + +// Response from /stripe/finalize-checkout. Either page or token fields +// are populated depending on purchase_type. +export const finalizeCheckoutResponse = z.object({ + purchase_type: z.enum(["page_packs", "premium_tokens"]), + status: pagePurchaseStatusEnum, + pages_limit: z.number().nullable().optional(), + pages_used: z.number().nullable().optional(), + pages_granted: z.number().nullable().optional(), + premium_credit_micros_limit: z.number().nullable().optional(), + premium_credit_micros_used: z.number().nullable().optional(), + premium_credit_micros_granted: z.number().nullable().optional(), +}); + +export type PagePurchaseStatus = z.infer<typeof pagePurchaseStatusEnum>; +export type CreateCheckoutSessionRequest = z.infer<typeof createCheckoutSessionRequest>; +export type CreateCheckoutSessionResponse = z.infer<typeof createCheckoutSessionResponse>; +export type StripeStatusResponse = z.infer<typeof stripeStatusResponse>; export type PagePurchase = z.infer<typeof pagePurchase>; export type GetPagePurchasesResponse = z.infer<typeof getPagePurchasesResponse>; +export type CreateTokenCheckoutSessionRequest = z.infer<typeof createTokenCheckoutSessionRequest>; +export type CreateTokenCheckoutSessionResponse = z.infer<typeof createTokenCheckoutSessionResponse>; +export type TokenStripeStatusResponse = z.infer<typeof tokenStripeStatusResponse>; +export type TokenPurchaseStatus = z.infer<typeof tokenPurchaseStatusEnum>; +export type TokenPurchase = z.infer<typeof tokenPurchase>; +export type GetTokenPurchasesResponse = z.infer<typeof getTokenPurchasesResponse>; export type FinalizeCheckoutResponse = z.infer<typeof finalizeCheckoutResponse>; diff --git a/surfsense_web/contracts/types/user.types.ts b/surfsense_web/contracts/types/user.types.ts index 706656064..85fee49a8 100644 --- a/surfsense_web/contracts/types/user.types.ts +++ b/surfsense_web/contracts/types/user.types.ts @@ -6,7 +6,8 @@ export const user = z.object({ is_active: z.boolean(), is_superuser: z.boolean(), is_verified: z.boolean(), - credit_micros_balance: z.number(), + pages_limit: z.number(), + pages_used: z.number(), display_name: z.string().nullish(), avatar_url: z.string().nullish(), }); diff --git a/surfsense_web/docker-entrypoint.js b/surfsense_web/docker-entrypoint.js new file mode 100644 index 000000000..8323f5652 --- /dev/null +++ b/surfsense_web/docker-entrypoint.js @@ -0,0 +1,88 @@ +/** + * Runtime environment variable substitution for Next.js Docker images. + * + * Next.js inlines NEXT_PUBLIC_* values at build time. The Docker image is built + * with unique placeholder strings (e.g. __NEXT_PUBLIC_FASTAPI_BACKEND_URL__). + * This script replaces those placeholders with real values from the container's + * environment variables before the server starts. + * + * Runs once at container startup via docker-entrypoint.sh. + */ + +const fs = require("fs"); +const path = require("path"); + +const replacements = [ + [ + "__NEXT_PUBLIC_FASTAPI_BACKEND_URL__", + process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000", + ], + [ + "__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__", + process.env.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE || "LOCAL", + ], + ["__NEXT_PUBLIC_ETL_SERVICE__", process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"], + [ + "__NEXT_PUBLIC_ZERO_CACHE_URL__", + process.env.NEXT_PUBLIC_ZERO_CACHE_URL || "http://localhost:4848", + ], + ["__NEXT_PUBLIC_DEPLOYMENT_MODE__", process.env.NEXT_PUBLIC_DEPLOYMENT_MODE || "self-hosted"], +]; + +let filesProcessed = 0; +let filesModified = 0; + +function walk(dir) { + let entries; + try { + entries = fs.readdirSync(dir, { withFileTypes: true }); + } catch { + return; + } + for (const entry of entries) { + const full = path.join(dir, entry.name); + if (entry.isDirectory()) { + walk(full); + } else if (entry.name.endsWith(".js")) { + filesProcessed++; + let content = fs.readFileSync(full, "utf8"); + let changed = false; + for (const [placeholder, value] of replacements) { + if (content.includes(placeholder)) { + content = content.replaceAll(placeholder, value); + changed = true; + } + } + if (changed) { + fs.writeFileSync(full, content); + filesModified++; + } + } + } +} + +console.log("[entrypoint] Replacing environment variable placeholders..."); +for (const [placeholder, value] of replacements) { + console.log(` ${placeholder} -> ${value}`); +} + +walk(path.join(__dirname, ".next")); + +const serverJs = path.join(__dirname, "server.js"); +if (fs.existsSync(serverJs)) { + let content = fs.readFileSync(serverJs, "utf8"); + let changed = false; + filesProcessed++; + for (const [placeholder, value] of replacements) { + if (content.includes(placeholder)) { + content = content.replaceAll(placeholder, value); + changed = true; + } + } + if (changed) { + fs.writeFileSync(serverJs, content); + filesModified++; + } +} + +console.log(`[entrypoint] Done. Scanned ${filesProcessed} files, modified ${filesModified}.`); diff --git a/surfsense_web/docker-entrypoint.sh b/surfsense_web/docker-entrypoint.sh new file mode 100644 index 000000000..7f4dfbf25 --- /dev/null +++ b/surfsense_web/docker-entrypoint.sh @@ -0,0 +1,6 @@ +#!/bin/sh +set -e + +node /app/docker-entrypoint.js + +exec node server.js diff --git a/surfsense_web/eslint.config.mjs b/surfsense_web/eslint.config.mjs index 9531332bb..530b65478 100644 --- a/surfsense_web/eslint.config.mjs +++ b/surfsense_web/eslint.config.mjs @@ -9,33 +9,6 @@ const compat = new FlatCompat({ baseDirectory: __dirname, }); -const eslintConfig = [ - ...compat.extends("next/core-web-vitals", "next/typescript"), - { - rules: { - "no-restricted-imports": [ - "error", - { - paths: [ - { - name: "@/lib/env-config", - importNames: ["BACKEND_URL"], - message: - "Use buildBackendUrl(path, params) for browser-facing backend URLs. BACKEND_URL is empty in proxy mode; importing it bypasses the single URL seam.", - }, - ], - patterns: [ - { - group: ["**/env-config", "**/env-config.ts"], - importNames: ["BACKEND_URL"], - message: - "Use buildBackendUrl(path, params). Import BACKEND_URL only inside lib/env-config.ts.", - }, - ], - }, - ], - }, - }, -]; +const eslintConfig = [...compat.extends("next/core-web-vitals", "next/typescript")]; export default eslintConfig; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index fd3ad3a6a..e74994221 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -3,11 +3,18 @@ import { useAtomValue } from "jotai"; import { useMemo } from "react"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; /** * A single model the user may pick for an automation slot. @@ -37,45 +44,48 @@ export interface AutomationEligibleModels { isLoading: boolean; } +interface GlobalConfigLike { + id: number; + name: string; + model_name: string; + provider: string; + is_premium?: boolean; + is_auto_mode?: boolean; +} + +interface UserConfigLike { + id: number; + name: string; + model_name: string; + provider: string; +} + /** * Build the eligible option list for one model kind: premium globals - * followed by all BYOK/search-space models. + * (`is_premium === true`, never Auto mode) followed by all BYOK configs. */ function buildKind( - globals: ConnectionRead[] | undefined, - byok: ConnectionRead[] | undefined, - capability: "chat" | "image_gen" | "vision", + globals: GlobalConfigLike[] | undefined, + byok: UserConfigLike[] | undefined, prefId: number | null | undefined ): EligibleModelKind { - const supportsCapability = (model: ModelRead) => { - if (capability === "chat") return Boolean(model.supports_chat); - if (capability === "vision") return Boolean(model.supports_image_input); - return Boolean(model.supports_image_generation); - }; - const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ - id: model.id, - name: model.display_name || model.model_id, - modelName: model.model_id, - provider: connection.provider, - isBYOK, - }); + const premiumGlobals: EligibleModelOption[] = (globals ?? []) + .filter((c) => c.is_premium === true && !c.is_auto_mode) + .map((c) => ({ + id: c.id, + name: c.name, + modelName: c.model_name, + provider: c.provider, + isBYOK: false, + })); - const premiumGlobals: EligibleModelOption[] = (globals ?? []).flatMap((connection) => - connection.models - .filter( - (model) => - model.enabled && - supportsCapability(model) && - String(model.billing_tier ?? "").toLowerCase() === "premium" - ) - .map((model) => toOption(connection, model, false)) - ); - - const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => - connection.models - .filter((model) => model.enabled && supportsCapability(model)) - .map((model) => toOption(connection, model, true)) - ); + const byokOptions: EligibleModelOption[] = (byok ?? []).map((c) => ({ + id: c.id, + name: c.name, + modelName: c.model_name, + provider: c.provider, + isBYOK: true, + })); const options = [...premiumGlobals, ...byokOptions]; const byId = new Map<number, EligibleModelOption>(options.map((o) => [o.id, o])); @@ -95,32 +105,46 @@ function buildKind( * (premium globals + user BYOK — never free globals or Auto mode), with a * default selection seeded from the search space's role preferences. * - * Everything is derived during render from the connection/model query atoms; + * Everything is derived during render from the existing config query atoms; * there are no effects, so option lists/maps keep stable references. */ export function useAutomationEligibleModels(): AutomationEligibleModels { - const { data: byokConnections, isLoading: byokLoading } = useAtomValue(modelConnectionsAtom); - const { data: globalConnections, isLoading: globalLoading } = useAtomValue( - globalModelConnectionsAtom + const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); + const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = + useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); + const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = + useAtomValue(globalImageGenConfigsAtom); + const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); + const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( + globalVisionLLMConfigsAtom ); - const { data: roles, isLoading: rolesLoading } = useAtomValue(modelRolesAtom); + const { data: visionUserConfigs, isLoading: visionUserLoading } = + useAtomValue(visionLLMConfigsAtom); const llm = useMemo( - () => buildKind(globalConnections, byokConnections, "chat", roles?.chat_model_id), - [globalConnections, byokConnections, roles?.chat_model_id] + () => buildKind(llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id), + [llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id] ); const image = useMemo( - () => buildKind(globalConnections, byokConnections, "image_gen", roles?.image_gen_model_id), - [globalConnections, byokConnections, roles?.image_gen_model_id] + () => buildKind(imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id), + [imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id] ); const vision = useMemo( - () => buildKind(globalConnections, byokConnections, "vision", roles?.vision_model_id), - [globalConnections, byokConnections, roles?.vision_model_id] + () => buildKind(visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id), + [visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id] ); - const isLoading = byokLoading || globalLoading || rolesLoading; + const isLoading = + llmUserLoading || + llmGlobalLoading || + prefsLoading || + imageGlobalLoading || + imageUserLoading || + visionGlobalLoading || + visionUserLoading; return useMemo(() => ({ llm, image, vision, isLoading }), [llm, image, vision, isLoading]); } diff --git a/surfsense_web/hooks/use-inbox.ts b/surfsense_web/hooks/use-inbox.ts index 860c0e01a..e1070219a 100644 --- a/surfsense_web/hooks/use-inbox.ts +++ b/surfsense_web/hooks/use-inbox.ts @@ -22,7 +22,7 @@ const CATEGORY_TYPES: Record<NotificationCategory, string[]> = { "connector_indexing", "connector_deletion", "document_processing", - "insufficient_credits", + "page_limit_exceeded", ], }; diff --git a/surfsense_web/hooks/use-podcast-live.ts b/surfsense_web/hooks/use-podcast-live.ts deleted file mode 100644 index e0a30e05b..000000000 --- a/surfsense_web/hooks/use-podcast-live.ts +++ /dev/null @@ -1,59 +0,0 @@ -"use client"; - -import { useQuery } from "@rocicorp/zero/react"; -import { useMemo } from "react"; -import { type PodcastSpec, type PodcastStatus, podcastSpec } from "@/contracts/types/podcast.types"; -import { queries } from "@/zero/queries"; - -/** - * Thin live row sourced from Zero's `podcasts` publication. Drives the - * lifecycle UI by push (no polling); heavy fields (transcript, audio) stay on - * REST and are fetched lazily when a gate or the player needs them. - */ -export interface LivePodcast { - id: number; - title: string; - status: PodcastStatus; - spec: PodcastSpec | null; - specVersion: number; - durationSeconds: number | null; - error: string | null; - searchSpaceId: number; - threadId: number | null; -} - -interface UsePodcastLiveResult { - podcast: LivePodcast | undefined; - isLoading: boolean; -} - -export function usePodcastLive(podcastId: number | undefined): UsePodcastLiveResult { - const [row, result] = useQuery(queries.podcasts.byId({ podcastId: podcastId ?? -1 })); - - const podcast = useMemo<LivePodcast | undefined>(() => { - if (!podcastId || !row) return undefined; - return { - id: row.id, - title: row.title, - status: row.status as PodcastStatus, - spec: parseSpec(row.spec), - specVersion: row.specVersion, - durationSeconds: row.durationSeconds ?? null, - error: row.error ?? null, - searchSpaceId: row.searchSpaceId, - threadId: row.threadId ?? null, - }; - }, [podcastId, row]); - - // Pre-hydration window: no row AND Zero hasn't confirmed completeness yet. - const isLoading = !!podcastId && !row && result.type !== "complete"; - - return { podcast, isLoading }; -} - -/** The JSONB column holds the snake_case spec; reject anything malformed. */ -function parseSpec(raw: unknown): PodcastSpec | null { - if (raw == null) return null; - const parsed = podcastSpec.safeParse(raw); - return parsed.success ? parsed.data : null; -} diff --git a/surfsense_web/hooks/use-search-source-connectors.ts b/surfsense_web/hooks/use-search-source-connectors.ts index 30083dcc3..ad0db3de6 100644 --- a/surfsense_web/hooks/use-search-source-connectors.ts +++ b/surfsense_web/hooks/use-search-source-connectors.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useState } from "react"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; export interface SearchSourceConnector { id: number; name: string; @@ -106,15 +106,16 @@ export const useSearchSourceConnectors = (lazy: boolean = false, searchSpaceId?: setIsLoading(true); setError(null); - const response = await authenticatedFetch( - buildBackendUrl("/api/v1/search-source-connectors", { - search_space_id: spaceId, - }), - { - method: "GET", - headers: { "Content-Type": "application/json" }, - } - ); + // Build URL with optional search_space_id query parameter + const url = new URL(`${BACKEND_URL}/api/v1/search-source-connectors`); + if (spaceId !== undefined) { + url.searchParams.append("search_space_id", spaceId.toString()); + } + + const response = await authenticatedFetch(url.toString(), { + method: "GET", + headers: { "Content-Type": "application/json" }, + }); if (!response.ok) { throw new Error(`Failed to fetch connectors: ${response.statusText}`); @@ -165,16 +166,15 @@ export const useSearchSourceConnectors = (lazy: boolean = false, searchSpaceId?: spaceId: number ) => { try { - const response = await authenticatedFetch( - buildBackendUrl("/api/v1/search-source-connectors", { - search_space_id: spaceId, - }), - { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(connectorData), - } - ); + // Add search_space_id as a query parameter + const url = new URL(`${BACKEND_URL}/api/v1/search-source-connectors`); + url.searchParams.append("search_space_id", spaceId.toString()); + + const response = await authenticatedFetch(url.toString(), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(connectorData), + }); if (!response.ok) { throw new Error(`Failed to create connector: ${response.statusText}`); @@ -204,7 +204,7 @@ export const useSearchSourceConnectors = (lazy: boolean = false, searchSpaceId?: ) => { try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-source-connectors/${connectorId}`), + `${BACKEND_URL}/api/v1/search-source-connectors/${connectorId}`, { method: "PUT", headers: { "Content-Type": "application/json" }, @@ -235,7 +235,7 @@ export const useSearchSourceConnectors = (lazy: boolean = false, searchSpaceId?: const deleteConnector = async (connectorId: number) => { try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-source-connectors/${connectorId}`), + `${BACKEND_URL}/api/v1/search-source-connectors/${connectorId}`, { method: "DELETE", headers: { "Content-Type": "application/json" }, @@ -267,12 +267,19 @@ export const useSearchSourceConnectors = (lazy: boolean = false, searchSpaceId?: endDate?: string ) => { try { + // Build query parameters + const params = new URLSearchParams({ + search_space_id: searchSpaceId.toString(), + }); + if (startDate) { + params.append("start_date", startDate); + } + if (endDate) { + params.append("end_date", endDate); + } + const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-source-connectors/${connectorId}/index`, { - search_space_id: searchSpaceId, - start_date: startDate, - end_date: endDate, - }), + `${BACKEND_URL}/api/v1/search-source-connectors/${connectorId}/index?${params.toString()}`, { method: "POST", headers: { "Content-Type": "application/json" }, diff --git a/surfsense_web/lib/apis/anonymous-chat-api.service.ts b/surfsense_web/lib/apis/anonymous-chat-api.service.ts index 5cdc139e1..843576a50 100644 --- a/surfsense_web/lib/apis/anonymous-chat-api.service.ts +++ b/surfsense_web/lib/apis/anonymous-chat-api.service.ts @@ -7,7 +7,7 @@ import { getAnonModelResponse, getAnonModelsResponse, } from "@/contracts/types/anonymous-chat.types"; -import { buildBackendUrl } from "../env-config"; +import { BACKEND_URL } from "../env-config"; import { ValidationError } from "../error"; const BASE = "/api/v1/public/anon-chat"; @@ -17,8 +17,14 @@ export type AnonUploadResult = | { ok: false; reason: "quota_exceeded" }; class AnonymousChatApiService { + private baseUrl: string; + + constructor(baseUrl: string) { + this.baseUrl = baseUrl; + } + private fullUrl(path: string): string { - return buildBackendUrl(`${BASE}${path}`); + return `${this.baseUrl}${BASE}${path}`; } getModels = async (): Promise<AnonModel[]> => { @@ -96,4 +102,4 @@ class AnonymousChatApiService { }; } -export const anonymousChatApiService = new AnonymousChatApiService(); +export const anonymousChatApiService = new AnonymousChatApiService(BACKEND_URL); diff --git a/surfsense_web/lib/apis/base-api.service.ts b/surfsense_web/lib/apis/base-api.service.ts index 678293d8e..a0039b63a 100644 --- a/surfsense_web/lib/apis/base-api.service.ts +++ b/surfsense_web/lib/apis/base-api.service.ts @@ -1,5 +1,5 @@ import type { ZodType } from "zod"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { getClientPlatform } from "../agent-filesystem"; import { getBearerToken, handleUnauthorized, refreshAccessToken } from "../auth-utils"; import { @@ -31,6 +31,8 @@ export type RequestOptions = { }; class BaseApiService { + baseUrl: string; + noAuthEndpoints: string[] = ["/auth/jwt/login", "/auth/register", "/auth/refresh"]; // Prefixes that don't require auth (checked with startsWith) @@ -42,9 +44,12 @@ class BaseApiService { return typeof window !== "undefined" ? getBearerToken() || "" : ""; } + constructor(baseUrl: string) { + this.baseUrl = baseUrl; + } + // Keep for backward compatibility, but token is now always read from localStorage setBearerToken(_bearerToken: string) { - void _bearerToken; // No-op: token is now always read fresh from localStorage via the getter } @@ -88,6 +93,11 @@ class BaseApiService { }, }; + // Validate the base URL + if (!this.baseUrl) { + throw new AppError("Base URL is not set."); + } + // Validate the bearer token const isNoAuthEndpoint = this.noAuthEndpoints.includes(url) || @@ -97,7 +107,8 @@ class BaseApiService { throw new AuthenticationError("You are not authenticated. Please login again."); } - const fullUrl = buildBackendUrl(url); + // Construct the full URL + const fullUrl = new URL(url, this.baseUrl).toString(); // Prepare fetch options const fetchOptions: RequestInit = { @@ -373,8 +384,7 @@ class BaseApiService { options?: Omit<RequestOptions, "method" | "responseType" | "body"> & { body: FormData } ) { // Remove Content-Type from options headers if present - const headersWithoutContentType = { ...(options?.headers ?? {}) }; - delete headersWithoutContentType["Content-Type"]; + const { "Content-Type": _, ...headersWithoutContentType } = options?.headers ?? {}; return this.request(url, responseSchema, { method: "POST", @@ -389,4 +399,4 @@ class BaseApiService { } } -export const baseApiService = new BaseApiService(); +export const baseApiService = new BaseApiService(BACKEND_URL); diff --git a/surfsense_web/lib/apis/image-gen-config-api.service.ts b/surfsense_web/lib/apis/image-gen-config-api.service.ts new file mode 100644 index 000000000..a9d444d21 --- /dev/null +++ b/surfsense_web/lib/apis/image-gen-config-api.service.ts @@ -0,0 +1,81 @@ +import { + type CreateImageGenConfigRequest, + createImageGenConfigRequest, + createImageGenConfigResponse, + deleteImageGenConfigResponse, + getGlobalImageGenConfigsResponse, + getImageGenConfigsResponse, + type UpdateImageGenConfigRequest, + updateImageGenConfigRequest, + updateImageGenConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ImageGenConfigApiService { + /** + * Get all global image generation configs (from YAML, negative IDs) + */ + getGlobalConfigs = async () => { + return baseApiService.get( + `/api/v1/global-image-generation-configs`, + getGlobalImageGenConfigsResponse + ); + }; + + /** + * Create a new image generation config for a search space + */ + createConfig = async (request: CreateImageGenConfigRequest) => { + const parsed = createImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + return baseApiService.post(`/api/v1/image-generation-configs`, createImageGenConfigResponse, { + body: parsed.data, + }); + }; + + /** + * Get image generation configs for a search space + */ + getConfigs = async (searchSpaceId: number) => { + const params = new URLSearchParams({ + search_space_id: String(searchSpaceId), + }).toString(); + return baseApiService.get( + `/api/v1/image-generation-configs?${params}`, + getImageGenConfigsResponse + ); + }; + + /** + * Update an existing image generation config + */ + updateConfig = async (request: UpdateImageGenConfigRequest) => { + const parsed = updateImageGenConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + const { id, data } = parsed.data; + return baseApiService.put( + `/api/v1/image-generation-configs/${id}`, + updateImageGenConfigResponse, + { body: data } + ); + }; + + /** + * Delete an image generation config + */ + deleteConfig = async (id: number) => { + return baseApiService.delete( + `/api/v1/image-generation-configs/${id}`, + deleteImageGenConfigResponse + ); + }; +} + +export const imageGenConfigApiService = new ImageGenConfigApiService(); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts deleted file mode 100644 index c69bcbef2..000000000 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ /dev/null @@ -1,172 +0,0 @@ -import { - type ConnectionCreateRequest, - type ConnectionRead, - type ConnectionUpdateRequest, - connectionCreateRequest, - connectionListResponse, - connectionRead, - connectionUpdateRequest, - type GlobalLlmConfigStatus, - globalLlmConfigStatus, - type ModelCreateRequest, - type ModelPreviewRead, - type ModelProviderRead, - type ModelRead, - type ModelRoles, - type ModelsBulkUpdateRequest, - type ModelTestPreviewRequest, - type ModelUpdateRequest, - modelCreateRequest, - modelListResponse, - modelPreviewListResponse, - modelProviderListResponse, - modelRead, - modelRoles, - modelsBulkUpdateRequest, - modelTestPreviewRequest, - modelUpdateRequest, - type VerifyConnectionResponse, - verifyConnectionResponse, -} from "@/contracts/types/model-connections.types"; -import { ValidationError } from "../error"; -import { baseApiService } from "./base-api.service"; - -class ModelConnectionsApiService { - getGlobalConnections = async (): Promise<ConnectionRead[]> => { - return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); - }; - - getGlobalLlmConfigStatus = async (): Promise<GlobalLlmConfigStatus> => { - return baseApiService.get(`/api/v1/global-llm-config-status`, globalLlmConfigStatus); - }; - - getModelProviders = async (): Promise<ModelProviderRead[]> => { - return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse); - }; - - getConnections = async (searchSpaceId: number): Promise<ConnectionRead[]> => { - return baseApiService.get( - `/api/v1/model-connections?search_space_id=${searchSpaceId}`, - connectionListResponse - ); - }; - - createConnection = async (request: ConnectionCreateRequest): Promise<ConnectionRead> => { - const parsed = connectionCreateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.post(`/api/v1/model-connections`, connectionRead, { - body: parsed.data, - }); - }; - - updateConnection = async ( - id: number, - request: ConnectionUpdateRequest - ): Promise<ConnectionRead> => { - const parsed = connectionUpdateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.put(`/api/v1/model-connections/${id}`, connectionRead, { - body: parsed.data, - }); - }; - - deleteConnection = async (id: number) => { - return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); - }; - - verifyConnection = async (id: number): Promise<VerifyConnectionResponse> => { - return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); - }; - - discoverModels = async (id: number): Promise<ModelRead[]> => { - return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); - }; - - previewModels = async (request: ConnectionCreateRequest): Promise<ModelPreviewRead[]> => { - const parsed = connectionCreateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.post( - `/api/v1/model-connections/discover-preview`, - modelPreviewListResponse, - { - body: parsed.data, - } - ); - }; - - testPreviewModel = async ( - request: ModelTestPreviewRequest - ): Promise<VerifyConnectionResponse> => { - const parsed = modelTestPreviewRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.post(`/api/v1/model-connections/test-preview`, verifyConnectionResponse, { - body: parsed.data, - }); - }; - - addManualModel = async ( - connectionId: number, - request: ModelCreateRequest - ): Promise<ModelRead> => { - const parsed = modelCreateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.post(`/api/v1/model-connections/${connectionId}/models`, modelRead, { - body: parsed.data, - }); - }; - - updateModel = async (id: number, request: ModelUpdateRequest): Promise<ModelRead> => { - const parsed = modelUpdateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.put(`/api/v1/models/${id}`, modelRead, { - body: parsed.data, - }); - }; - - bulkUpdateModels = async ( - connectionId: number, - request: ModelsBulkUpdateRequest - ): Promise<ModelRead[]> => { - const parsed = modelsBulkUpdateRequest.safeParse(request); - if (!parsed.success) { - throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); - } - return baseApiService.request( - `/api/v1/model-connections/${connectionId}/models`, - modelListResponse, - { - method: "PATCH", - headers: { "Content-Type": "application/json" }, - body: parsed.data, - } - ); - }; - - testModel = async (id: number): Promise<VerifyConnectionResponse> => { - return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); - }; - - getModelRoles = async (searchSpaceId: number): Promise<ModelRoles> => { - return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); - }; - - updateModelRoles = async (searchSpaceId: number, roles: ModelRoles): Promise<ModelRoles> => { - return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { - body: roles, - }); - }; -} - -export const modelConnectionsApiService = new ModelConnectionsApiService(); diff --git a/surfsense_web/lib/apis/new-llm-config-api.service.ts b/surfsense_web/lib/apis/new-llm-config-api.service.ts new file mode 100644 index 000000000..a1040a9bc --- /dev/null +++ b/surfsense_web/lib/apis/new-llm-config-api.service.ts @@ -0,0 +1,178 @@ +import { + type CreateNewLLMConfigRequest, + createNewLLMConfigRequest, + createNewLLMConfigResponse, + type DeleteNewLLMConfigRequest, + deleteNewLLMConfigRequest, + deleteNewLLMConfigResponse, + type GetNewLLMConfigRequest, + type GetNewLLMConfigsRequest, + getDefaultSystemInstructionsResponse, + getGlobalNewLLMConfigsResponse, + getLLMPreferencesResponse, + getModelListResponse, + getNewLLMConfigRequest, + getNewLLMConfigResponse, + getNewLLMConfigsRequest, + getNewLLMConfigsResponse, + type UpdateLLMPreferencesRequest, + type UpdateNewLLMConfigRequest, + updateLLMPreferencesRequest, + updateLLMPreferencesResponse, + updateNewLLMConfigRequest, + updateNewLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class NewLLMConfigApiService { + /** + * Get all global NewLLMConfigs available to all users + */ + getGlobalConfigs = async () => { + return baseApiService.get(`/api/v1/global-new-llm-configs`, getGlobalNewLLMConfigsResponse); + }; + + /** + * Get default system instructions template + */ + getDefaultSystemInstructions = async () => { + return baseApiService.get( + `/api/v1/new-llm-configs/default-system-instructions`, + getDefaultSystemInstructionsResponse + ); + }; + + /** + * Create a new NewLLMConfig for a search space + */ + createConfig = async (request: CreateNewLLMConfigRequest) => { + const parsedRequest = createNewLLMConfigRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + return baseApiService.post(`/api/v1/new-llm-configs`, createNewLLMConfigResponse, { + body: parsedRequest.data, + }); + }; + + /** + * Get a list of NewLLMConfigs for a search space + */ + getConfigs = async (request: GetNewLLMConfigsRequest) => { + const parsedRequest = getNewLLMConfigsRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + const queryParams = new URLSearchParams({ + search_space_id: String(parsedRequest.data.search_space_id), + ...(parsedRequest.data.skip !== undefined && { skip: String(parsedRequest.data.skip) }), + ...(parsedRequest.data.limit !== undefined && { limit: String(parsedRequest.data.limit) }), + }).toString(); + + return baseApiService.get(`/api/v1/new-llm-configs?${queryParams}`, getNewLLMConfigsResponse); + }; + + /** + * Get a single NewLLMConfig by ID + */ + getConfig = async (request: GetNewLLMConfigRequest) => { + const parsedRequest = getNewLLMConfigRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + return baseApiService.get( + `/api/v1/new-llm-configs/${parsedRequest.data.id}`, + getNewLLMConfigResponse + ); + }; + + /** + * Update an existing NewLLMConfig + */ + updateConfig = async (request: UpdateNewLLMConfigRequest) => { + const parsedRequest = updateNewLLMConfigRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + const { id, data } = parsedRequest.data; + + return baseApiService.put(`/api/v1/new-llm-configs/${id}`, updateNewLLMConfigResponse, { + body: data, + }); + }; + + /** + * Delete a NewLLMConfig + */ + deleteConfig = async (request: DeleteNewLLMConfigRequest) => { + const parsedRequest = deleteNewLLMConfigRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + return baseApiService.delete( + `/api/v1/new-llm-configs/${parsedRequest.data.id}`, + deleteNewLLMConfigResponse + ); + }; + + /** + * Get LLM preferences for a search space + */ + getLLMPreferences = async (searchSpaceId: number) => { + return baseApiService.get( + `/api/v1/search-spaces/${searchSpaceId}/llm-preferences`, + getLLMPreferencesResponse + ); + }; + + /** + * Get the dynamic model catalogue (sourced from OpenRouter API) + */ + getModels = async () => { + return baseApiService.get(`/api/v1/models`, getModelListResponse); + }; + + /** + * Update LLM preferences for a search space + */ + updateLLMPreferences = async (request: UpdateLLMPreferencesRequest) => { + const parsedRequest = updateLLMPreferencesRequest.safeParse(request); + + if (!parsedRequest.success) { + console.error("Invalid request:", parsedRequest.error); + const errorMessage = parsedRequest.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + const { search_space_id, data } = parsedRequest.data; + + return baseApiService.put( + `/api/v1/search-spaces/${search_space_id}/llm-preferences`, + updateLLMPreferencesResponse, + { body: data } + ); + }; +} + +export const newLLMConfigApiService = new NewLLMConfigApiService(); diff --git a/surfsense_web/lib/apis/podcasts-api.service.ts b/surfsense_web/lib/apis/podcasts-api.service.ts deleted file mode 100644 index 2e13d63cc..000000000 --- a/surfsense_web/lib/apis/podcasts-api.service.ts +++ /dev/null @@ -1,76 +0,0 @@ -import { z } from "zod"; -import { - languageOptions, - type PodcastSpec, - podcastDetail, - updateSpecRequest, - voiceOption, -} from "@/contracts/types/podcast.types"; -import { ValidationError } from "../error"; -import { baseApiService } from "./base-api.service"; - -const BASE = "/api/v1/podcasts"; - -const voiceOptionList = z.array(voiceOption); - -class PodcastsApiService { - // Full state including the deserialized brief and transcript; thin lifecycle - // fields (status, spec, spec_version) also arrive live via Zero. - getDetail = async (podcastId: number) => { - return baseApiService.get(`${BASE}/${podcastId}`, podcastDetail); - }; - - // Guarded by the version the caller last saw; the backend answers 409 when - // the brief changed underneath them. - updateSpec = async (podcastId: number, spec: PodcastSpec, expectedVersion: number) => { - const parsed = updateSpecRequest.safeParse({ spec, expected_version: expectedVersion }); - if (!parsed.success) { - throw new ValidationError( - `Invalid request: ${parsed.error.issues.map((i) => i.message).join(", ")}` - ); - } - return baseApiService.patch(`${BASE}/${podcastId}/spec`, podcastDetail, { - body: parsed.data, - }); - }; - - approveBrief = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/brief/approve`, podcastDetail); - }; - - // Reopens the brief gate; the transcript and audio are replaced once the - // user re-approves. - regenerate = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/transcript/regenerate`, podcastDetail); - }; - - // Backs out of a regeneration: the podcast returns to ready with its - // existing audio untouched. 409 when there is no episode to fall back to. - revertRegeneration = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/regenerate/revert`, podcastDetail); - }; - - // Only for podcasts that have produced nothing yet; once an episode - // exists the backend refuses (409) and revertRegeneration is the way back. - cancel = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/cancel`, podcastDetail); - }; - - listVoices = async (language?: string) => { - const qs = language ? `?${new URLSearchParams({ language })}` : ""; - return baseApiService.get(`${BASE}/voices${qs}`, voiceOptionList); - }; - - // The languages the active provider can offer; the brief form renders - // exactly this list and only opens free entry when the backend allows it. - listLanguages = async () => { - return baseApiService.get(`${BASE}/languages`, languageOptions); - }; - - // A short audio sample of a voice, cached server-side per voice. - previewVoice = async (voiceId: string) => { - return baseApiService.getBlob(`${BASE}/voices/${encodeURIComponent(voiceId)}/preview`); - }; -} - -export const podcastsApiService = new PodcastsApiService(); diff --git a/surfsense_web/lib/apis/stripe-api.service.ts b/surfsense_web/lib/apis/stripe-api.service.ts index b2f5698fb..f119fbf6a 100644 --- a/surfsense_web/lib/apis/stripe-api.service.ts +++ b/surfsense_web/lib/apis/stripe-api.service.ts @@ -1,50 +1,64 @@ import { - type AutoReloadSettingsResponse, - autoReloadSettingsResponse, - type CreateAutoReloadSetupSessionRequest, - type CreateAutoReloadSetupSessionResponse, - type CreateCreditCheckoutSessionRequest, - type CreateCreditCheckoutSessionResponse, - type CreditStripeStatusResponse, - createAutoReloadSetupSessionResponse, - createCreditCheckoutSessionResponse, - creditStripeStatusResponse, + type CreateCheckoutSessionRequest, + type CreateCheckoutSessionResponse, + type CreateTokenCheckoutSessionRequest, + type CreateTokenCheckoutSessionResponse, + createCheckoutSessionResponse, + createTokenCheckoutSessionResponse, type FinalizeCheckoutResponse, finalizeCheckoutResponse, - type GetCreditPurchasesResponse, type GetPagePurchasesResponse, - getCreditPurchasesResponse, + type GetTokenPurchasesResponse, getPagePurchasesResponse, - type UpdateAutoReloadSettingsRequest, + getTokenPurchasesResponse, + type StripeStatusResponse, + stripeStatusResponse, + type TokenStripeStatusResponse, + tokenStripeStatusResponse, } from "@/contracts/types/stripe.types"; import { baseApiService } from "./base-api.service"; class StripeApiService { - createCreditCheckoutSession = async ( - request: CreateCreditCheckoutSessionRequest - ): Promise<CreateCreditCheckoutSessionResponse> => { + createCheckoutSession = async ( + request: CreateCheckoutSessionRequest + ): Promise<CreateCheckoutSessionResponse> => { return baseApiService.post( - "/api/v1/stripe/create-credit-checkout-session", - createCreditCheckoutSessionResponse, + "/api/v1/stripe/create-checkout-session", + createCheckoutSessionResponse, + { + body: request, + } + ); + }; + + getPurchases = async (): Promise<GetPagePurchasesResponse> => { + return baseApiService.get("/api/v1/stripe/purchases", getPagePurchasesResponse); + }; + + getStatus = async (): Promise<StripeStatusResponse> => { + return baseApiService.get("/api/v1/stripe/status", stripeStatusResponse); + }; + + createTokenCheckoutSession = async ( + request: CreateTokenCheckoutSessionRequest + ): Promise<CreateTokenCheckoutSessionResponse> => { + return baseApiService.post( + "/api/v1/stripe/create-token-checkout-session", + createTokenCheckoutSessionResponse, { body: request } ); }; - getCreditStatus = async (): Promise<CreditStripeStatusResponse> => { - return baseApiService.get("/api/v1/stripe/credit-status", creditStripeStatusResponse); + getTokenStatus = async (): Promise<TokenStripeStatusResponse> => { + return baseApiService.get("/api/v1/stripe/token-status", tokenStripeStatusResponse); }; - getCreditPurchases = async (): Promise<GetCreditPurchasesResponse> => { - return baseApiService.get("/api/v1/stripe/credit-purchases", getCreditPurchasesResponse); - }; - - /** Legacy page-purchase history (read-only; page buying is removed). */ - getPagePurchases = async (): Promise<GetPagePurchasesResponse> => { - return baseApiService.get("/api/v1/stripe/purchases", getPagePurchasesResponse); + getTokenPurchases = async (): Promise<GetTokenPurchasesResponse> => { + return baseApiService.get("/api/v1/stripe/token-purchases", getTokenPurchasesResponse); }; /** - * Synchronously fulfil a credit checkout session from the success page. + * Synchronously fulfil a checkout session from the success page. * * Solves the race where the user lands on /purchase-success before * Stripe's checkout.session.completed webhook arrives. Idempotent — @@ -56,30 +70,6 @@ class StripeApiService { finalizeCheckoutResponse ); }; - - // --- Auto-reload -------------------------------------------------------- - - getAutoReloadSettings = async (): Promise<AutoReloadSettingsResponse> => { - return baseApiService.get("/api/v1/stripe/auto-reload", autoReloadSettingsResponse); - }; - - updateAutoReloadSettings = async ( - request: UpdateAutoReloadSettingsRequest - ): Promise<AutoReloadSettingsResponse> => { - return baseApiService.put("/api/v1/stripe/auto-reload", autoReloadSettingsResponse, { - body: request, - }); - }; - - createAutoReloadSetupSession = async ( - request: CreateAutoReloadSetupSessionRequest - ): Promise<CreateAutoReloadSetupSessionResponse> => { - return baseApiService.post( - "/api/v1/stripe/auto-reload/setup", - createAutoReloadSetupSessionResponse, - { body: request } - ); - }; } export const stripeApiService = new StripeApiService(); diff --git a/surfsense_web/lib/apis/vision-llm-config-api.service.ts b/surfsense_web/lib/apis/vision-llm-config-api.service.ts new file mode 100644 index 000000000..537cecbd1 --- /dev/null +++ b/surfsense_web/lib/apis/vision-llm-config-api.service.ts @@ -0,0 +1,63 @@ +import { + type CreateVisionLLMConfigRequest, + createVisionLLMConfigRequest, + createVisionLLMConfigResponse, + deleteVisionLLMConfigResponse, + getGlobalVisionLLMConfigsResponse, + getModelListResponse, + getVisionLLMConfigsResponse, + type UpdateVisionLLMConfigRequest, + updateVisionLLMConfigRequest, + updateVisionLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class VisionLLMConfigApiService { + getModels = async () => { + return baseApiService.get(`/api/v1/vision-models`, getModelListResponse); + }; + + getGlobalConfigs = async () => { + return baseApiService.get( + `/api/v1/global-vision-llm-configs`, + getGlobalVisionLLMConfigsResponse + ); + }; + + createConfig = async (request: CreateVisionLLMConfigRequest) => { + const parsed = createVisionLLMConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + return baseApiService.post(`/api/v1/vision-llm-configs`, createVisionLLMConfigResponse, { + body: parsed.data, + }); + }; + + getConfigs = async (searchSpaceId: number) => { + const params = new URLSearchParams({ + search_space_id: String(searchSpaceId), + }).toString(); + return baseApiService.get(`/api/v1/vision-llm-configs?${params}`, getVisionLLMConfigsResponse); + }; + + updateConfig = async (request: UpdateVisionLLMConfigRequest) => { + const parsed = updateVisionLLMConfigRequest.safeParse(request); + if (!parsed.success) { + const msg = parsed.error.issues.map((i) => i.message).join(", "); + throw new ValidationError(`Invalid request: ${msg}`); + } + const { id, data } = parsed.data; + return baseApiService.put(`/api/v1/vision-llm-configs/${id}`, updateVisionLLMConfigResponse, { + body: data, + }); + }; + + deleteConfig = async (id: number) => { + return baseApiService.delete(`/api/v1/vision-llm-configs/${id}`, deleteVisionLLMConfigResponse); + }; +} + +export const visionLLMConfigApiService = new VisionLLMConfigApiService(); diff --git a/surfsense_web/lib/auth-utils.ts b/surfsense_web/lib/auth-utils.ts index 8ad10308b..b7dab7717 100644 --- a/surfsense_web/lib/auth-utils.ts +++ b/surfsense_web/lib/auth-utils.ts @@ -1,7 +1,7 @@ /** * Authentication utilities for handling token expiration and redirects */ -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; const REDIRECT_PATH_KEY = "surfsense_redirect_path"; const BEARER_TOKEN_KEY = "surfsense_bearer_token"; @@ -195,7 +195,7 @@ export async function logout(): Promise<boolean> { // Call backend to revoke the refresh token if (refreshToken) { try { - const response = await fetch(buildBackendUrl("/auth/jwt/revoke"), { + const response = await fetch(`${BACKEND_URL}/auth/jwt/revoke`, { method: "POST", headers: { "Content-Type": "application/json", @@ -273,7 +273,7 @@ export async function refreshAccessToken(): Promise<string | null> { isRefreshing = true; refreshPromise = (async () => { try { - const response = await fetch(buildBackendUrl("/auth/jwt/refresh"), { + const response = await fetch(`${BACKEND_URL}/auth/jwt/refresh`, { method: "POST", headers: { "Content-Type": "application/json", diff --git a/surfsense_web/lib/automations/builder-schema.ts b/surfsense_web/lib/automations/builder-schema.ts index 5bb034bef..c2bd69209 100644 --- a/surfsense_web/lib/automations/builder-schema.ts +++ b/surfsense_web/lib/automations/builder-schema.ts @@ -73,7 +73,7 @@ export type BuilderExecution = z.infer<typeof builderExecutionSchema>; * later chat/search-space model changes. */ export const builderModelsSchema = z.object({ - chatModelId: z.number().int(), + agentLlmId: z.number().int(), imageConfigId: z.number().int(), visionConfigId: z.number().int(), }); @@ -90,7 +90,7 @@ export const builderFormSchema = z.object({ tags: z.array(z.string()), /** Carried through from an edited definition so we don't drop it. */ goal: z.string().nullable(), - /** Selected chat/image/vision models (``0`` = use the eligible default). */ + /** Selected agent/image/vision models (``0`` = use the eligible default). */ models: builderModelsSchema, }); export type BuilderForm = z.infer<typeof builderFormSchema>; @@ -147,7 +147,7 @@ export function createEmptyForm(): BuilderForm { }, tags: [], goal: null, - models: { chatModelId: 0, imageConfigId: 0, visionConfigId: 0 }, + models: { agentLlmId: 0, imageConfigId: 0, visionConfigId: 0 }, }; } @@ -240,9 +240,9 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { ...(hasResolvedModels(form.models) ? { models: { - chat_model_id: form.models.chatModelId, - image_gen_model_id: form.models.imageConfigId, - vision_model_id: form.models.visionConfigId, + agent_llm_id: form.models.agentLlmId, + image_generation_config_id: form.models.imageConfigId, + vision_llm_config_id: form.models.visionConfigId, }, } : {}), @@ -251,7 +251,7 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { /** True once every model slot holds a concrete (non-zero) id. */ export function hasResolvedModels(models: BuilderModels): boolean { - return models.chatModelId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; + return models.agentLlmId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; } /** The desired schedule trigger for this form, or ``null`` if none. */ @@ -500,9 +500,9 @@ function modelsFromDefinition(raw: unknown): BuilderModels { const m = asRecord(raw); const num = (value: unknown) => (typeof value === "number" ? value : 0); return { - chatModelId: num(m.chat_model_id), - imageConfigId: num(m.image_gen_model_id), - visionConfigId: num(m.vision_model_id), + agentLlmId: num(m.agent_llm_id), + imageConfigId: num(m.image_generation_config_id), + visionConfigId: num(m.vision_llm_config_id), }; } diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 92924f0f7..1c67d59a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -5,10 +5,6 @@ export type ChatErrorKind = | "thread_busy" | "send_failed_pre_accept" | "auth_expired" - | "model_auth_failed" - | "model_not_found" - | "model_context_limit" - | "model_provider_unavailable" | "rate_limited" | "network_offline" | "stream_interrupted" @@ -18,7 +14,7 @@ export type ChatErrorKind = | "server_error" | "unknown"; -export type ChatErrorChannel = "pinned_inline" | "inline" | "toast" | "silent"; +export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; export type ChatErrorSeverity = "info" | "warn" | "error"; @@ -210,66 +206,6 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if (errorCode === "MODEL_AUTH_FAILED") { - return { - kind: "model_auth_failed", - channel: "toast", - severity: "warn", - telemetryEvent: "chat_blocked", - isExpected: true, - userMessage: - "This model’s API key is invalid or expired. Switch models, or update the API key.", - rawMessage, - errorCode: errorCode ?? "MODEL_AUTH_FAILED", - details: { flow: input.flow, providerErrorType }, - }; - } - - if (errorCode === "MODEL_NOT_FOUND") { - return { - kind: "model_not_found", - channel: "toast", - severity: "warn", - telemetryEvent: "chat_blocked", - isExpected: true, - userMessage: - "This model is unavailable or no longer exists. Switch to another model and try again.", - rawMessage, - errorCode: errorCode ?? "MODEL_NOT_FOUND", - details: { flow: input.flow, providerErrorType }, - }; - } - - if (errorCode === "MODEL_CONTEXT_LIMIT") { - return { - kind: "model_context_limit", - channel: "toast", - severity: "warn", - telemetryEvent: "chat_blocked", - isExpected: true, - userMessage: - "This request is too large for the selected model. Reduce the input or switch models.", - rawMessage, - errorCode: errorCode ?? "MODEL_CONTEXT_LIMIT", - details: { flow: input.flow, providerErrorType }, - }; - } - - if (errorCode === "MODEL_PROVIDER_UNAVAILABLE") { - return { - kind: "model_provider_unavailable", - channel: "toast", - severity: "warn", - telemetryEvent: "chat_blocked", - isExpected: true, - userMessage: - "The selected model provider is temporarily unavailable. Please try again or switch models.", - rawMessage, - errorCode: errorCode ?? "MODEL_PROVIDER_UNAVAILABLE", - details: { flow: input.flow, providerErrorType }, - }; - } - if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") { return { kind: "rate_limited", diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index c86c72d66..e0dfb3cc4 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -91,10 +91,6 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", - "MODEL_AUTH_FAILED", - "MODEL_NOT_FOUND", - "MODEL_CONTEXT_LIMIT", - "MODEL_PROVIDER_UNAVAILABLE", "RATE_LIMITED", "NETWORK_ERROR", "STREAM_PARSE_ERROR", diff --git a/surfsense_web/lib/chat/podcast-state.ts b/surfsense_web/lib/chat/podcast-state.ts new file mode 100644 index 000000000..061a89b63 --- /dev/null +++ b/surfsense_web/lib/chat/podcast-state.ts @@ -0,0 +1,73 @@ +/** + * Module-level state for tracking active podcast generation. + * Used by the new-chat adapter to prevent duplicate podcast requests. + */ + +type PodcastStateListener = (isGenerating: boolean) => void; + +let _activePodcastTaskId: string | null = null; +const _listeners: Set<PodcastStateListener> = new Set(); + +/** + * Check if a podcast is currently being generated + */ +export function isPodcastGenerating(): boolean { + return _activePodcastTaskId !== null; +} + +/** + * Get the active podcast task ID + */ +export function getActivePodcastTaskId(): string | null { + return _activePodcastTaskId; +} + +/** + * Set the active podcast task ID (when podcast generation starts) + */ +export function setActivePodcastTaskId(taskId: string): void { + _activePodcastTaskId = taskId; + notifyListeners(); +} + +/** + * Clear the active podcast task ID (when podcast generation completes or errors) + */ +export function clearActivePodcastTaskId(): void { + _activePodcastTaskId = null; + notifyListeners(); +} + +/** + * Subscribe to podcast state changes + */ +export function subscribeToPodcastState(listener: PodcastStateListener): () => void { + _listeners.add(listener); + return () => { + _listeners.delete(listener); + }; +} + +function notifyListeners(): void { + const isGenerating = _activePodcastTaskId !== null; + for (const listener of _listeners) { + listener(isGenerating); + } +} + +/** + * Check if a message looks like a podcast request + */ +export function looksLikePodcastRequest(message: string): boolean { + const podcastPatterns = [ + /\bpodcast\b/i, + /\bcreate.*podcast\b/i, + /\bgenerate.*podcast\b/i, + /\bmake.*podcast\b/i, + /\bturn.*into.*podcast\b/i, + /\bpodcast.*about\b/i, + /\bgive.*podcast\b/i, + ]; + + return podcastPatterns.some((pattern) => pattern.test(message)); +} diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index dc5846f23..d30b87665 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -4,7 +4,7 @@ */ import { baseApiService } from "@/lib/apis/base-api.service"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; // ============================================================================= // Types matching backend schemas // ============================================================================= @@ -227,5 +227,5 @@ export interface RegenerateParams { * Get the URL for the regenerate endpoint (for streaming fetch) */ export function getRegenerateUrl(threadId: number): string { - return buildBackendUrl(`/api/v1/threads/${threadId}/regenerate`); + return `${BACKEND_URL}/api/v1/threads/${threadId}/regenerate`; } diff --git a/surfsense_web/lib/env-config.ts b/surfsense_web/lib/env-config.ts index 8c671029c..80db395c6 100644 --- a/surfsense_web/lib/env-config.ts +++ b/surfsense_web/lib/env-config.ts @@ -1,81 +1,47 @@ /** * Environment configuration for the frontend. * - * Docker deployments use same-origin relative browser URLs behind Caddy. - * NEXT_PUBLIC_* values remain only as build-time fallbacks for packaged clients - * like Electron, where there is no bundled Caddy origin. + * This file centralizes access to NEXT_PUBLIC_* environment variables. + * For Docker deployments, these placeholders are replaced at container startup + * via sed in the entrypoint script. + * + * IMPORTANT: Do not use template literals or complex expressions with these values + * as it may prevent the sed replacement from working correctly. */ import packageJson from "../package.json"; -// Build-time fallback for packaged clients. Docker runtime reads plain AUTH_TYPE -// through the runtime config provider first, then falls back to this baked value. -export const BUILD_TIME_AUTH_TYPE = process.env.NEXT_PUBLIC_AUTH_TYPE || "GOOGLE"; +// Auth type: "LOCAL" for email/password, "GOOGLE" for OAuth +// Placeholder: __NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__ +export const AUTH_TYPE = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE || "GOOGLE"; -// Backend API URL. An empty string is valid in proxy mode and means -// same-origin relative requests (e.g. /api/v1/... and /auth/...). -export const BACKEND_URL = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ?? ""; +// Backend API URL +// Placeholder: __NEXT_PUBLIC_FASTAPI_BACKEND_URL__ +export const BACKEND_URL = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; -type BackendUrlParam = string | number | boolean | null | undefined; +// ETL Service: "DOCLING", "UNSTRUCTURED", or "LLAMACLOUD" +// Placeholder: __NEXT_PUBLIC_ETL_SERVICE__ +export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"; -/** - * Build browser-facing backend URLs without breaking proxy mode. - * - * In proxy mode BACKEND_URL intentionally stays empty, so callers must keep - * same-origin relative URLs ("/api/v1/...") and let Caddy route them. When - * BACKEND_URL is explicitly configured, the same path resolves against that - * absolute backend origin. - */ -export function buildBackendUrl(path: string, params?: Record<string, BackendUrlParam>): string { - const backendPath = path.startsWith("/") ? path : `/${path}`; - const queryParams = new URLSearchParams(); - - if (params) { - for (const [key, value] of Object.entries(params)) { - if (value !== null && value !== undefined) { - queryParams.append(key, String(value)); - } - } - } - - if (BACKEND_URL) { - const url = new URL(backendPath, BACKEND_URL); - for (const [key, value] of queryParams) { - url.searchParams.append(key, value); - } - return url.toString(); - } - - const queryString = queryParams.toString(); - if (!queryString) return backendPath; - return `${backendPath}${backendPath.includes("?") ? "&" : "?"}${queryString}`; -} - -// Server-side backend URL. Relative browser URLs do not work from RSC/API route -// code, so server callers should use Docker DNS or an explicit public backend. -export const SERVER_BACKEND_URL = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; - -// Build-time fallback for packaged clients. Docker runtime reads plain ETL_SERVICE -// through the runtime config provider first, then falls back to this baked value. -export const BUILD_TIME_ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"; - -// Build-time fallback for packaged clients. Docker runtime reads plain -// DEPLOYMENT_MODE through the runtime config provider first, then falls back to this baked value. -export const BUILD_TIME_DEPLOYMENT_MODE = process.env.NEXT_PUBLIC_DEPLOYMENT_MODE || "self-hosted"; +// Deployment Mode: "self-hosted" or "cloud" +// Matches backend's SURFSENSE_DEPLOYMENT_MODE - defaults to "self-hosted" +// self-hosted: Full access to local file system connectors (Obsidian, etc.) +// cloud: Only cloud-based connectors available +// Placeholder: __NEXT_PUBLIC_DEPLOYMENT_MODE__ +export const DEPLOYMENT_MODE = process.env.NEXT_PUBLIC_DEPLOYMENT_MODE || "self-hosted"; // App version - defaults to package.json version // Can be overridden at build time with NEXT_PUBLIC_APP_VERSION for full git tag version export const APP_VERSION = process.env.NEXT_PUBLIC_APP_VERSION || packageJson.version; -// Global announcement banner. Useful for planned downtime / maintenance notices. -// Toggle on Vercel via NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED ("true" to show) and -// set the copy with NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE. -export const GLOBAL_ANNOUNCEMENT_ENABLED = - process.env.NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED === "true"; +// Helper to check if local auth is enabled +export const isLocalAuth = () => AUTH_TYPE === "LOCAL"; -export const GLOBAL_ANNOUNCEMENT_MESSAGE = - process.env.NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE ?? ""; +// Helper to check if Google auth is enabled +export const isGoogleAuth = () => AUTH_TYPE === "GOOGLE"; + +// Helper to check if running in self-hosted mode +export const isSelfHosted = () => DEPLOYMENT_MODE === "self-hosted"; + +// Helper to check if running in cloud mode +export const isCloud = () => DEPLOYMENT_MODE === "cloud"; diff --git a/surfsense_web/lib/onboarding.ts b/surfsense_web/lib/onboarding.ts index 5ca597137..b87f822a0 100644 --- a/surfsense_web/lib/onboarding.ts +++ b/surfsense_web/lib/onboarding.ts @@ -1,28 +1,8 @@ -import type { ConnectionRead } from "@/contracts/types/model-connections.types"; - -export function hasEnabledChatModel(connections: ConnectionRead[]): boolean { - return connections.some( - (connection) => - connection.enabled && - connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) - ); -} - export function isLlmOnboardingComplete( - chatModelId: number | null | undefined, - globalConnections: ConnectionRead[], - searchSpaceConnections: ConnectionRead[] + agentLlmId: number | null | undefined, + hasGlobalConfigs: boolean ): boolean { - const connections = [...globalConnections, ...searchSpaceConnections]; - const resolvedChatModelId = chatModelId ?? 0; - - if (resolvedChatModelId === 0) { - return hasEnabledChatModel(connections); - } - - return connections.some((connection) => - connection.models.some( - (model) => model.id === resolvedChatModelId && model.enabled && Boolean(model.supports_chat) - ) - ); + if (agentLlmId === null || agentLlmId === undefined) return false; + if (agentLlmId === 0) return hasGlobalConfigs; + return true; } diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 41ac7e7b2..4dc644d5e 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -569,10 +569,10 @@ export function trackIncentivePageViewed() { safeCapture("incentive_page_viewed"); } -export function trackIncentiveTaskCompleted(taskType: string, creditMicrosRewarded: number) { +export function trackIncentiveTaskCompleted(taskType: string, pagesRewarded: number) { safeCapture("incentive_task_completed", { task_type: taskType, - credit_micros_rewarded: creditMicrosRewarded, + pages_rewarded: pagesRewarded, }); } @@ -609,9 +609,9 @@ interface AutomationCreatedProps { task_count?: number; trigger_type?: string; has_schedule?: boolean; - chat_model_id?: number; - image_gen_model_id?: number; - vision_model_id?: number; + agent_llm_id?: number; + image_generation_config_id?: number; + vision_llm_config_id?: number; tags_count?: number; } @@ -705,9 +705,9 @@ interface AutomationChatDecisionProps { edited?: boolean; task_count?: number; trigger_type?: string; - chat_model_id?: number; - image_gen_model_id?: number; - vision_model_id?: number; + agent_llm_id?: number; + image_generation_config_id?: number; + vision_llm_config_id?: number; } export function trackAutomationChatApproved(props: AutomationChatDecisionProps) { diff --git a/surfsense_web/lib/provider-icons.tsx b/surfsense_web/lib/provider-icons.tsx index d3e799720..e63c5eb2f 100644 --- a/surfsense_web/lib/provider-icons.tsx +++ b/surfsense_web/lib/provider-icons.tsx @@ -1,11 +1,10 @@ -import { Cpu, Shuffle } from "lucide-react"; +import { Bot, Shuffle } from "lucide-react"; import { Ai21Icon, + AnthropicIcon, AnyscaleIcon, - AzureIcon, BedrockIcon, CerebrasIcon, - ClaudeIcon, CloudflareIcon, CohereIcon, CometApiIcon, @@ -17,7 +16,6 @@ import { GitHubModelsIcon, GroqIcon, HuggingFaceIcon, - LmStudioIcon, MiniMaxIcon, MistralIcon, MoonshotIcon, @@ -38,8 +36,6 @@ import { } from "@/components/icons/providers"; import { cn } from "@/lib/utils"; -export const AUTO_PROVIDER_ICON_KEY = "AUTO"; - /** * Returns a Lucide icon element for the given LLM / image-gen provider. * Accepts an optional `className` override for the icon size. @@ -48,7 +44,7 @@ export function getProviderIcon( provider: string, { isAutoMode, className = "size-4" }: { isAutoMode?: boolean; className?: string } = {} ) { - if (isAutoMode || provider?.toUpperCase() === AUTO_PROVIDER_ICON_KEY) { + if (isAutoMode || provider?.toUpperCase() === "AUTO") { return <Shuffle className={cn(className, "text-muted-foreground")} />; } @@ -58,13 +54,12 @@ export function getProviderIcon( case "ALIBABA_QWEN": return <QwenIcon className={cn(className)} />; case "ANTHROPIC": - case "CLAUDE": - return <ClaudeIcon className={cn(className)} />; + return <AnthropicIcon className={cn(className)} />; case "ANYSCALE": return <AnyscaleIcon className={cn(className)} />; case "AZURE": case "AZURE_OPENAI": - return <AzureIcon className={cn(className)} />; + return <OpenaiIcon className={cn(className)} />; case "AWS_BEDROCK": case "BEDROCK": return <BedrockIcon className={cn(className)} />; @@ -77,7 +72,7 @@ export function getProviderIcon( case "COMETAPI": return <CometApiIcon className={cn(className)} />; case "CUSTOM": - return <Cpu className={cn(className)} />; + return <Bot className={cn(className, "text-gray-400")} />; case "DATABRICKS": return <DatabricksIcon className={cn(className)} />; case "DEEPINFRA": @@ -94,8 +89,6 @@ export function getProviderIcon( return <GroqIcon className={cn(className)} />; case "HUGGINGFACE": return <HuggingFaceIcon className={cn(className)} />; - case "LM_STUDIO": - return <LmStudioIcon className={cn(className)} />; case "MINIMAX": return <MiniMaxIcon className={cn(className)} />; case "MISTRAL": @@ -105,7 +98,6 @@ export function getProviderIcon( case "NSCALE": return <NscaleIcon className={cn(className)} />; case "OLLAMA": - case "OLLAMA_CHAT": return <OllamaIcon className={cn(className)} />; case "OPENAI": return <OpenaiIcon className={cn(className)} />; @@ -130,6 +122,6 @@ export function getProviderIcon( case "ZHIPU": return <ZhipuIcon className={cn(className)} />; default: - return <Cpu className={cn(className, "text-muted-foreground")} />; + return <Bot className={cn(className, "text-muted-foreground")} />; } } diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 193f53b3c..6f8885d7e 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -36,12 +36,24 @@ export const cacheKeys = { withQueryParams: (queries: GetLogsRequest["queryParams"]) => ["logs", "with-query-params", ...stableEntries(queries)] as const, }, - modelConnections: { - all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, - global: () => ["model-connections", "global"] as const, - globalConfigStatus: () => ["model-connections", "global-config-status"] as const, - providers: () => ["model-connections", "providers"] as const, - roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, + newLLMConfigs: { + all: (searchSpaceId: number) => ["new-llm-configs", searchSpaceId] as const, + byId: (configId: number) => ["new-llm-configs", "detail", configId] as const, + preferences: (searchSpaceId: number) => ["llm-preferences", searchSpaceId] as const, + defaultInstructions: () => ["new-llm-configs", "default-instructions"] as const, + global: () => ["new-llm-configs", "global"] as const, + modelList: () => ["models", "catalogue"] as const, + }, + imageGenConfigs: { + all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, + byId: (configId: number) => ["image-gen-configs", "detail", configId] as const, + global: () => ["image-gen-configs", "global"] as const, + }, + visionLLMConfigs: { + all: (searchSpaceId: number) => ["vision-llm-configs", searchSpaceId] as const, + byId: (configId: number) => ["vision-llm-configs", "detail", configId] as const, + global: () => ["vision-llm-configs", "global"] as const, + modelList: () => ["vision-models", "catalogue"] as const, }, auth: { user: ["auth", "user"] as const, diff --git a/surfsense_web/lib/runtime-auth-config.ts b/surfsense_web/lib/runtime-auth-config.ts deleted file mode 100644 index 9e8d1921d..000000000 --- a/surfsense_web/lib/runtime-auth-config.ts +++ /dev/null @@ -1,52 +0,0 @@ -export const RUNTIME_AUTH_TYPE_COOKIE_NAME = "surfsense_auth_type"; - -export type RuntimeAuthUiMode = "GOOGLE" | "LOCAL"; - -export function resolveRuntimeAuthUiMode( - value: string | null | undefined, - fallback: string | null | undefined = "GOOGLE" -): RuntimeAuthUiMode { - const candidate = value?.trim().toUpperCase(); - if (candidate === "GOOGLE") return "GOOGLE"; - if (candidate === "LOCAL") return "LOCAL"; - - const fallbackCandidate = fallback?.trim().toUpperCase(); - return fallbackCandidate === "GOOGLE" ? "GOOGLE" : "LOCAL"; -} - -export function getRuntimeAuthInitScript(fallbackAuthType: string): string { - const fallback = resolveRuntimeAuthUiMode(fallbackAuthType); - const cookieName = JSON.stringify(RUNTIME_AUTH_TYPE_COOKIE_NAME); - const fallbackValue = JSON.stringify(fallback); - - return ` -(function() { - try { - var cookieName = ${cookieName}; - var fallback = ${fallbackValue}; - var prefix = cookieName + "="; - var rawValue = fallback; - var cookies = document.cookie ? document.cookie.split(";") : []; - for (var i = 0; i < cookies.length; i++) { - var cookie = cookies[i].trim(); - if (cookie.indexOf(prefix) === 0) { - rawValue = decodeURIComponent(cookie.slice(prefix.length)); - break; - } - } - var normalized = String(rawValue || fallback).toUpperCase() === "GOOGLE" ? "GOOGLE" : "LOCAL"; - window.__SURFSENSE_AUTH_TYPE__ = normalized; - document.documentElement.setAttribute("data-surfsense-auth-type", normalized); - } catch (_) { - window.__SURFSENSE_AUTH_TYPE__ = ${fallbackValue}; - document.documentElement.setAttribute("data-surfsense-auth-type", ${fallbackValue}); - } -})(); -`; -} - -declare global { - interface Window { - __SURFSENSE_AUTH_TYPE__?: RuntimeAuthUiMode; - } -} diff --git a/surfsense_web/lib/source.ts b/surfsense_web/lib/source.ts index 13fb58f22..f71e8b688 100644 --- a/surfsense_web/lib/source.ts +++ b/surfsense_web/lib/source.ts @@ -4,7 +4,6 @@ import { ClipboardCheck, Compass, Container, - Cpu, Download, FlaskConical, Heart, @@ -26,7 +25,6 @@ const DOCS_ICONS: Record<string, React.ComponentType> = { ClipboardCheck, Compass, Container, - Cpu, Download, FlaskConical, Heart, diff --git a/surfsense_web/lib/supported-extensions.ts b/surfsense_web/lib/supported-extensions.ts index 7005cd698..f615b3d46 100644 --- a/surfsense_web/lib/supported-extensions.ts +++ b/surfsense_web/lib/supported-extensions.ts @@ -75,23 +75,18 @@ export const FILE_TYPE_CONFIG: Record<string, Record<string, string[]>> = { }, }; -export function getAcceptedFileTypes(etlService?: string): Record<string, string[]> { +export function getAcceptedFileTypes(): Record<string, string[]> { + const etlService = process.env.NEXT_PUBLIC_ETL_SERVICE; return FILE_TYPE_CONFIG[etlService || "default"] || FILE_TYPE_CONFIG.default; } -export function getSupportedExtensions( - acceptedFileTypes?: Record<string, string[]>, - etlService?: string -): string[] { - const types = acceptedFileTypes ?? getAcceptedFileTypes(etlService); +export function getSupportedExtensions(acceptedFileTypes?: Record<string, string[]>): string[] { + const types = acceptedFileTypes ?? getAcceptedFileTypes(); return Array.from(new Set(Object.values(types).flat())).sort(); } export function getSupportedExtensionsSet( - acceptedFileTypes?: Record<string, string[]>, - etlService?: string + acceptedFileTypes?: Record<string, string[]> ): Set<string> { - return new Set( - getSupportedExtensions(acceptedFileTypes, etlService).map((ext) => ext.toLowerCase()) - ); + return new Set(getSupportedExtensions(acceptedFileTypes).map((ext) => ext.toLowerCase())); } diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index 866ba4844..a13942e64 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -476,7 +476,9 @@ "title": "Settings", "subtitle": "Manage your LLM configurations and role assignments for this search space.", "back_to_dashboard": "Back to Dashboard", + "model_configs": "Model Configs", "models": "Models", + "llm_roles": "LLM Roles", "roles": "Roles", "llm_role_management": "LLM Role Management", "llm_role_desc": "Assign your LLM configurations to specific roles for different purposes.", @@ -741,9 +743,14 @@ "back_to_app": "Back to app", "nav_general": "General", "nav_general_desc": "Name, description & basic info", - "nav_models": "Models", - "nav_agent_models": "Chat Models", + "nav_agent_models": "Agent Models", "nav_agent_models_desc": "Models with prompts & citations", + "nav_role_assignments": "Role Assignments", + "nav_role_assignments_desc": "Assign configs to agent roles", + "nav_image_models": "Image Models", + "nav_image_models_desc": "Configure image generation models", + "nav_vision_models": "Vision Models", + "nav_vision_models_desc": "Configure vision-capable LLM models", "nav_system_instructions": "System Instructions", "nav_system_instructions_desc": "SearchSpace-wide AI instructions", "nav_public_links": "Public Chat Links", diff --git a/surfsense_web/messages/es.json b/surfsense_web/messages/es.json index f7755b47e..33ae79c52 100644 --- a/surfsense_web/messages/es.json +++ b/surfsense_web/messages/es.json @@ -476,7 +476,9 @@ "title": "Configuración", "subtitle": "Administra tus configuraciones de LLM y asignaciones de roles para este espacio de búsqueda.", "back_to_dashboard": "Volver al panel de control", + "model_configs": "Configuraciones de modelos", "models": "Modelos", + "llm_roles": "Roles de LLM", "roles": "Roles", "llm_role_management": "Gestión de roles de LLM", "llm_role_desc": "Asigna tus configuraciones de LLM a roles específicos para diferentes propósitos.", @@ -741,9 +743,14 @@ "back_to_app": "Volver a la app", "nav_general": "General", "nav_general_desc": "Nombre, descripción e información básica", - "nav_models": "Modelos", - "nav_agent_models": "Modelos de chat", + "nav_agent_models": "Modelos de agente", "nav_agent_models_desc": "Modelos LLM con prompts y citas", + "nav_role_assignments": "Asignaciones de roles", + "nav_role_assignments_desc": "Asignar configuraciones a roles de agente", + "nav_image_models": "Modelos de imagen", + "nav_image_models_desc": "Configurar modelos de generación de imágenes", + "nav_vision_models": "Modelos de visión", + "nav_vision_models_desc": "Configurar modelos LLM con capacidad de visión", "nav_system_instructions": "Instrucciones del sistema", "nav_system_instructions_desc": "Instrucciones de IA a nivel del espacio de búsqueda", "nav_public_links": "Enlaces de chat públicos", @@ -759,27 +766,7 @@ "general_reset": "Restablecer cambios", "general_save": "Guardar cambios", "general_saving": "Guardando", - "general_unsaved_changes": "Tienes cambios sin guardar. Haz clic en \"Guardar cambios\" para aplicarlos.", - "nav_web_search": "Búsqueda web", - "nav_web_search_desc": "Configuración de búsqueda web integrada", - "web_search_title": "Búsqueda web", - "web_search_description": "La búsqueda web funciona con una instancia SearXNG integrada. Todas las consultas se procesan a través de tu servidor; no se envían datos a terceros.", - "web_search_enabled_label": "Activar búsqueda web", - "web_search_enabled_description": "Cuando está activada, el agente de IA puede buscar en la web información en tiempo real como noticias, precios y eventos actuales.", - "web_search_status_healthy": "El servicio de búsqueda web está funcionando", - "web_search_status_unhealthy": "El servicio de búsqueda web no está disponible", - "web_search_status_not_configured": "El servicio de búsqueda web no está configurado", - "web_search_engines_label": "Motores de búsqueda", - "web_search_engines_placeholder": "google,brave,duckduckgo", - "web_search_engines_description": "Lista separada por comas de motores SearXNG a usar. Déjalo vacío para usar los valores predeterminados.", - "web_search_language_label": "Idioma preferido", - "web_search_language_placeholder": "es", - "web_search_language_description": "Etiqueta de idioma IETF (por ejemplo, es, es-ES). Déjalo vacío para detección automática.", - "web_search_safesearch_label": "Nivel de SafeSearch", - "web_search_safesearch_description": "0 = desactivado, 1 = moderado, 2 = estricto", - "web_search_save": "Guardar configuración de búsqueda web", - "web_search_saving": "Guardando...", - "web_search_saved": "Configuración de búsqueda web guardada" + "general_unsaved_changes": "Tienes cambios sin guardar. Haz clic en \"Guardar cambios\" para aplicarlos." }, "homepage": { "hero_title_part1": "El espacio de trabajo con IA", diff --git a/surfsense_web/messages/hi.json b/surfsense_web/messages/hi.json index 038555f1e..7a26d0c1d 100644 --- a/surfsense_web/messages/hi.json +++ b/surfsense_web/messages/hi.json @@ -476,7 +476,9 @@ "title": "सेटिंग्स", "subtitle": "इस सर्च स्पेस के लिए अपनी LLM कॉन्फ़िगरेशन और भूमिका असाइनमेंट प्रबंधित करें।", "back_to_dashboard": "डैशबोर्ड पर वापस जाएं", + "model_configs": "मॉडल कॉन्फ़िगरेशन", "models": "मॉडल", + "llm_roles": "LLM भूमिकाएं", "roles": "भूमिकाएं", "llm_role_management": "LLM भूमिका प्रबंधन", "llm_role_desc": "विभिन्न उद्देश्यों के लिए अपनी LLM कॉन्फ़िगरेशन को विशिष्ट भूमिकाओं में असाइन करें।", @@ -741,9 +743,14 @@ "back_to_app": "ऐप पर वापस जाएं", "nav_general": "सामान्य", "nav_general_desc": "नाम, विवरण और बुनियादी जानकारी", - "nav_models": "मॉडल", - "nav_agent_models": "चैट मॉडल", + "nav_agent_models": "एजेंट मॉडल", "nav_agent_models_desc": "प्रॉम्प्ट और उद्धरण के साथ LLM मॉडल", + "nav_role_assignments": "भूमिका असाइनमेंट", + "nav_role_assignments_desc": "एजेंट भूमिकाओं को कॉन्फ़िगरेशन असाइन करें", + "nav_image_models": "इमेज मॉडल", + "nav_image_models_desc": "इमेज जनरेशन मॉडल कॉन्फ़िगर करें", + "nav_vision_models": "विज़न मॉडल", + "nav_vision_models_desc": "विज़न-सक्षम LLM मॉडल कॉन्फ़िगर करें", "nav_system_instructions": "सिस्टम निर्देश", "nav_system_instructions_desc": "सर्च स्पेस-व्यापी AI निर्देश", "nav_public_links": "सार्वजनिक चैट लिंक", @@ -759,27 +766,7 @@ "general_reset": "परिवर्तन रीसेट करें", "general_save": "परिवर्तन सहेजें", "general_saving": "सहेजा जा रहा है", - "general_unsaved_changes": "आपके पास सहेजे नहीं गए परिवर्तन हैं। उन्हें लागू करने के लिए \"परिवर्तन सहेजें\" पर क्लिक करें।", - "nav_web_search": "वेब खोज", - "nav_web_search_desc": "बिल्ट-इन वेब खोज सेटिंग्स", - "web_search_title": "वेब खोज", - "web_search_description": "वेब खोज एक बिल्ट-इन SearXNG इंस्टेंस द्वारा संचालित है। सभी क्वेरी आपके सर्वर के माध्यम से प्रॉक्सी की जाती हैं; कोई डेटा तृतीय पक्षों को नहीं भेजा जाता।", - "web_search_enabled_label": "वेब खोज सक्षम करें", - "web_search_enabled_description": "सक्षम होने पर, AI एजेंट समाचार, कीमतों और वर्तमान घटनाओं जैसी वास्तविक समय की जानकारी के लिए वेब खोज सकता है।", - "web_search_status_healthy": "वेब खोज सेवा स्वस्थ है", - "web_search_status_unhealthy": "वेब खोज सेवा उपलब्ध नहीं है", - "web_search_status_not_configured": "वेब खोज सेवा कॉन्फ़िगर नहीं है", - "web_search_engines_label": "खोज इंजन", - "web_search_engines_placeholder": "google,brave,duckduckgo", - "web_search_engines_description": "उपयोग करने के लिए SearXNG इंजनों की कॉमा-सेपरेटेड सूची। डिफ़ॉल्ट के लिए खाली छोड़ें।", - "web_search_language_label": "पसंदीदा भाषा", - "web_search_language_placeholder": "hi", - "web_search_language_description": "IETF भाषा टैग (जैसे hi, hi-IN)। ऑटो-डिटेक्ट के लिए खाली छोड़ें।", - "web_search_safesearch_label": "SafeSearch स्तर", - "web_search_safesearch_description": "0 = बंद, 1 = मध्यम, 2 = सख्त", - "web_search_save": "वेब खोज सेटिंग्स सहेजें", - "web_search_saving": "सहेजा जा रहा है...", - "web_search_saved": "वेब खोज सेटिंग्स सहेजी गईं" + "general_unsaved_changes": "आपके पास सहेजे नहीं गए परिवर्तन हैं। उन्हें लागू करने के लिए \"परिवर्तन सहेजें\" पर क्लिक करें।" }, "homepage": { "hero_title_part1": "AI कार्यक्षेत्र", diff --git a/surfsense_web/messages/pt.json b/surfsense_web/messages/pt.json index bcba8f70c..61c22e086 100644 --- a/surfsense_web/messages/pt.json +++ b/surfsense_web/messages/pt.json @@ -476,7 +476,9 @@ "title": "Configurações", "subtitle": "Gerencie suas configurações de LLM e atribuições de funções para este espaço de pesquisa.", "back_to_dashboard": "Voltar ao painel", + "model_configs": "Configurações de modelos", "models": "Modelos", + "llm_roles": "Funções de LLM", "roles": "Funções", "llm_role_management": "Gerenciamento de funções de LLM", "llm_role_desc": "Atribua suas configurações de LLM a funções específicas para diferentes propósitos.", @@ -741,9 +743,14 @@ "back_to_app": "Voltar ao app", "nav_general": "Geral", "nav_general_desc": "Nome, descrição e informações básicas", - "nav_models": "Modelos", - "nav_agent_models": "Modelos de chat", + "nav_agent_models": "Modelos do agente", "nav_agent_models_desc": "Modelos LLM com prompts e citações", + "nav_role_assignments": "Atribuições de funções", + "nav_role_assignments_desc": "Atribuir configurações a funções do agente", + "nav_image_models": "Modelos de imagem", + "nav_image_models_desc": "Configurar modelos de geração de imagens", + "nav_vision_models": "Modelos de visão", + "nav_vision_models_desc": "Configurar modelos LLM com capacidade de visão", "nav_system_instructions": "Instruções do sistema", "nav_system_instructions_desc": "Instruções de IA em nível do espaço de pesquisa", "nav_public_links": "Links de chat públicos", @@ -759,27 +766,7 @@ "general_reset": "Redefinir alterações", "general_save": "Salvar alterações", "general_saving": "Salvando", - "general_unsaved_changes": "Você tem alterações não salvas. Clique em \"Salvar alterações\" para aplicá-las.", - "nav_web_search": "Pesquisa na web", - "nav_web_search_desc": "Configurações integradas de pesquisa na web", - "web_search_title": "Pesquisa na web", - "web_search_description": "A pesquisa na web é alimentada por uma instância SearXNG integrada. Todas as consultas passam pelo seu servidor; nenhum dado é enviado a terceiros.", - "web_search_enabled_label": "Ativar pesquisa na web", - "web_search_enabled_description": "Quando ativado, o agente de IA pode pesquisar na web informações em tempo real, como notícias, preços e eventos atuais.", - "web_search_status_healthy": "O serviço de pesquisa na web está saudável", - "web_search_status_unhealthy": "O serviço de pesquisa na web está indisponível", - "web_search_status_not_configured": "O serviço de pesquisa na web não está configurado", - "web_search_engines_label": "Mecanismos de pesquisa", - "web_search_engines_placeholder": "google,brave,duckduckgo", - "web_search_engines_description": "Lista separada por vírgulas de mecanismos SearXNG a usar. Deixe vazio para os padrões.", - "web_search_language_label": "Idioma preferido", - "web_search_language_placeholder": "pt", - "web_search_language_description": "Tag de idioma IETF (por exemplo, pt, pt-BR). Deixe vazio para detecção automática.", - "web_search_safesearch_label": "Nível de SafeSearch", - "web_search_safesearch_description": "0 = desativado, 1 = moderado, 2 = rigoroso", - "web_search_save": "Salvar configurações de pesquisa na web", - "web_search_saving": "Salvando...", - "web_search_saved": "Configurações de pesquisa na web salvas" + "general_unsaved_changes": "Você tem alterações não salvas. Clique em \"Salvar alterações\" para aplicá-las." }, "homepage": { "hero_title_part1": "O espaço de trabalho com IA", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 5fea60eb8..7d0419cbd 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -96,10 +96,6 @@ "create_new_search_space": "创建新的搜索空间", "delete_title": "删除搜索空间", "delete_confirm": "您确定要删除「{name}」吗?此操作无法撤销,将永久删除所有数据。", - "leave": "退出", - "leave_title": "退出搜索空间", - "leave_confirm": "您确定要退出「{name}」吗?您将无法访问此搜索空间中的所有文档和对话。", - "leaving": "退出中...", "welcome_title": "欢迎使用 SurfSense", "welcome_description": "创建您的第一个搜索空间,开始组织知识、连接数据源并与AI对话。", "create_first_button": "创建第一个搜索空间" @@ -108,17 +104,6 @@ "title": "用户设置", "description": "管理您的账户设置和API访问", "back_to_app": "返回应用", - "profile_nav_label": "个人资料", - "profile_nav_description": "管理您的显示名称和头像", - "profile_title": "个人资料", - "profile_description": "更新您的个人信息", - "profile_avatar": "个人头像", - "profile_display_name": "显示名称", - "profile_display_name_hint": "这是您的名称在应用中的显示方式", - "profile_email": "电子邮件", - "profile_save": "保存更改", - "profile_saved": "个人资料已成功更新", - "profile_save_error": "无法更新个人资料", "api_key_nav_label": "API密钥", "api_key_nav_description": "管理您的API访问令牌", "api_key_title": "API密钥", @@ -475,7 +460,9 @@ "title": "设置", "subtitle": "管理此搜索空间的 LLM 配置和角色分配。", "back_to_dashboard": "返回仪表盘", + "model_configs": "模型配置", "models": "模型", + "llm_roles": "LLM 角色", "roles": "角色", "llm_role_management": "LLM 角色管理", "llm_role_desc": "为不同用途分配您的 LLM 配置到特定角色。", @@ -740,9 +727,14 @@ "back_to_app": "返回应用", "nav_general": "常规", "nav_general_desc": "名称、描述和基本信息", - "nav_models": "模型", - "nav_agent_models": "聊天模型", + "nav_agent_models": "代理模型", "nav_agent_models_desc": "LLM 模型配置提示词和引用", + "nav_role_assignments": "角色分配", + "nav_role_assignments_desc": "为代理角色分配配置", + "nav_image_models": "图像模型", + "nav_image_models_desc": "配置图像生成模型", + "nav_vision_models": "视觉模型", + "nav_vision_models_desc": "配置具有视觉能力的LLM模型", "nav_system_instructions": "系统指令", "nav_system_instructions_desc": "搜索空间级别的 AI 指令", "nav_public_links": "公开聊天链接", @@ -758,27 +750,7 @@ "general_reset": "重置更改", "general_save": "保存更改", "general_saving": "保存中...", - "general_unsaved_changes": "您有未保存的更改。点击\"保存更改\"以应用它们。", - "nav_web_search": "网页搜索", - "nav_web_search_desc": "内置网页搜索设置", - "web_search_title": "网页搜索", - "web_search_description": "网页搜索由内置 SearXNG 实例提供支持。所有查询都通过您的服务器代理;不会向第三方发送数据。", - "web_search_enabled_label": "启用网页搜索", - "web_search_enabled_description": "启用后,AI 代理可以搜索网页以获取新闻、价格和当前事件等实时信息。", - "web_search_status_healthy": "网页搜索服务运行正常", - "web_search_status_unhealthy": "网页搜索服务不可用", - "web_search_status_not_configured": "网页搜索服务未配置", - "web_search_engines_label": "搜索引擎", - "web_search_engines_placeholder": "google,brave,duckduckgo", - "web_search_engines_description": "要使用的 SearXNG 引擎的逗号分隔列表。留空则使用默认值。", - "web_search_language_label": "首选语言", - "web_search_language_placeholder": "zh", - "web_search_language_description": "IETF 语言标签(例如 zh、zh-CN)。留空则自动检测。", - "web_search_safesearch_label": "SafeSearch 级别", - "web_search_safesearch_description": "0 = 关闭,1 = 中等,2 = 严格", - "web_search_save": "保存网页搜索设置", - "web_search_saving": "保存中...", - "web_search_saved": "网页搜索设置已保存" + "general_unsaved_changes": "您有未保存的更改。点击\"保存更改\"以应用它们。" }, "homepage": { "hero_title_part1": "AI 工作空间", diff --git a/surfsense_web/package.json b/surfsense_web/package.json index 0f4d2ca33..2e999b42c 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -1,6 +1,6 @@ { "name": "surfsense_web", - "version": "0.0.29", + "version": "0.0.27", "private": true, "packageManager": "pnpm@10.26.0", "description": "SurfSense Frontend", @@ -31,8 +31,8 @@ "dependencies": { "@ai-sdk/react": "^1.2.12", "@ariakit/react": "^0.4.21", - "@assistant-ui/react": "^0.14.14", - "@assistant-ui/react-markdown": "^0.14.1", + "@assistant-ui/react": "^0.12.19", + "@assistant-ui/react-markdown": "^0.12.6", "@babel/standalone": "^7.29.2", "@hookform/resolvers": "^5.2.2", "@marsidev/react-turnstile": "^1.5.0", diff --git a/surfsense_web/playwright.config.ts b/surfsense_web/playwright.config.ts index 330ddd83b..ef066a9be 100644 --- a/surfsense_web/playwright.config.ts +++ b/surfsense_web/playwright.config.ts @@ -2,18 +2,12 @@ import { defineConfig, devices } from "@playwright/test"; const PORT = process.env.PORT || "3000"; const BACKEND_PORT = process.env.BACKEND_PORT || "8000"; -const ZERO_CACHE_PORT = process.env.ZERO_CACHE_PORT || "4848"; const baseURL = process.env.PLAYWRIGHT_BASE_URL || `http://localhost:${PORT}`; -const useProxyOrigin = process.env.PLAYWRIGHT_USE_PROXY_ORIGIN === "true"; -const backendURL = useProxyOrigin ? baseURL : `http://localhost:${BACKEND_PORT}`; -const zeroCacheURL = useProxyOrigin ? `${baseURL}/zero` : `http://localhost:${ZERO_CACHE_PORT}`; process.env.PLAYWRIGHT_TEST_EMAIL ??= "e2e-test@surfsense.net"; process.env.PLAYWRIGHT_TEST_PASSWORD ??= "E2eTestPassword123!"; -process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ??= backendURL; -process.env.SURFSENSE_BACKEND_INTERNAL_URL ??= backendURL; -process.env.AUTH_TYPE ??= "LOCAL"; -process.env.NEXT_PUBLIC_ZERO_CACHE_URL ??= zeroCacheURL; +process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL ??= `http://localhost:${BACKEND_PORT}`; +process.env.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE ??= "LOCAL"; /** * Playwright configuration for SurfSense web E2E tests. @@ -73,9 +67,7 @@ export default defineConfig({ stderr: "pipe", env: { NEXT_PUBLIC_FASTAPI_BACKEND_URL: process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL, - SURFSENSE_BACKEND_INTERNAL_URL: process.env.SURFSENSE_BACKEND_INTERNAL_URL, - AUTH_TYPE: process.env.AUTH_TYPE, - NEXT_PUBLIC_ZERO_CACHE_URL: process.env.NEXT_PUBLIC_ZERO_CACHE_URL, + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: process.env.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE, }, }, }); diff --git a/surfsense_web/pnpm-lock.yaml b/surfsense_web/pnpm-lock.yaml index 4a5b0b5d0..652eff8f5 100644 --- a/surfsense_web/pnpm-lock.yaml +++ b/surfsense_web/pnpm-lock.yaml @@ -15,11 +15,11 @@ importers: specifier: ^0.4.21 version: 0.4.21(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@assistant-ui/react': - specifier: ^0.14.14 - version: 0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + specifier: ^0.12.19 + version: 0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) '@assistant-ui/react-markdown': - specifier: ^0.14.1 - version: 0.14.1(@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: ^0.12.6 + version: 0.12.6(@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@babel/standalone': specifier: ^7.29.2 version: 7.29.2 @@ -498,13 +498,13 @@ packages: react: ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^17.0.0 || ^18.0.0 || ^19.0.0 - '@assistant-ui/core@0.2.10': - resolution: {integrity: sha512-0YyqlpZgg1Hoaq2X4jHAaMKXg+lGniLygNt1KrGFTPgbxeo8ZStRjWyyG2xIl+zlFKHiKCGHzflUHvlJi4IurA==} + '@assistant-ui/core@0.1.7': + resolution: {integrity: sha512-219T42ihVOicbJXZLWgD2CW5Bylg9Nk7geC331X4RfJxTDYlm2zIjViGlGaqfj6URXBp6kMulO2BTUrHGmAvdw==} peerDependencies: - '@assistant-ui/store': ^0.2.13 - '@assistant-ui/tap': ^0.5.14 + '@assistant-ui/store': ^0.2.3 + '@assistant-ui/tap': ^0.5.3 '@types/react': '*' - assistant-cloud: ^0.1.31 + assistant-cloud: ^0.1.22 react: ^18 || ^19 zustand: ^5.0.11 peerDependenciesMeta: @@ -517,18 +517,18 @@ packages: zustand: optional: true - '@assistant-ui/react-markdown@0.14.1': - resolution: {integrity: sha512-Q1S66rLS0J+b7jUjKrPGryLZsdg8v9NX/QdSTRmOCi5H6smWHfgMYvDypQ4BHn+4Tc+m3ggLKFPCgBV6t6iLhQ==} + '@assistant-ui/react-markdown@0.12.6': + resolution: {integrity: sha512-utJqsdDXB3UVZfOa3ErLpaTHraeXkDshR0D34shWdTHrmLyx9e/HypTu4+BgiSsxS+ME6t9WO9M3VeGDprfUcQ==} peerDependencies: - '@assistant-ui/react': ^0.14.8 + '@assistant-ui/react': ^0.12.19 '@types/react': '*' react: ^18 || ^19 peerDependenciesMeta: '@types/react': optional: true - '@assistant-ui/react@0.14.14': - resolution: {integrity: sha512-qS7YJewwFbmhs+yte56ZnO9jIOK+8hKo7mOK3cKDcCndn+jGSWTJmoNVIYQgMpB2JYIJ/SKZD+LeWSR6K3LL5g==} + '@assistant-ui/react@0.12.19': + resolution: {integrity: sha512-scAf0o8cwjuHT9Y44EFGXcE2y6BSmpeMvt0NxOn8+Y/HBlNttQMLNvrM0p2AjacXCUufagiafAnWybzBV3nKEQ==} peerDependencies: '@types/react': '*' '@types/react-dom': '*' @@ -540,18 +540,18 @@ packages: '@types/react-dom': optional: true - '@assistant-ui/store@0.2.13': - resolution: {integrity: sha512-7NL6HWMBxe1ndLWO4kHkjQ0Syyc0D/Aj+zxdpcy4yrplG71X04CzFimMBBSQAk+AnGBf+d96D7cuUZdjHkTavg==} + '@assistant-ui/store@0.2.3': + resolution: {integrity: sha512-daStbgSQiX7+csqK6Cvo7A8p8UZkTCSMxBHxbhJvwrlVbp7BRJWTxq3U3rpTkSGIar23SXIyVRRfXU8VW7pswA==} peerDependencies: - '@assistant-ui/tap': ^0.5.14 + '@assistant-ui/tap': ^0.5.3 '@types/react': '*' react: ^18 || ^19 peerDependenciesMeta: '@types/react': optional: true - '@assistant-ui/tap@0.5.14': - resolution: {integrity: sha512-SAy0ip8nKo72U8K9MuU7gYUR4tzoIi6k+HAQgev3zA/sWN7hr/QDDUTblrn5QB9Y/yycRiq8s98WD1vnDy8WMQ==} + '@assistant-ui/tap@0.5.3': + resolution: {integrity: sha512-wy06ksqF2LfFxe4JXy31Ns89N/be1Dy3c+mG363cFHFp3CbLkRu8CrCN2SQSgCkXt628E+D8QyzqdBcl9kD4NQ==} peerDependencies: '@types/react': '*' react: ^18 || ^19 @@ -5170,19 +5170,11 @@ packages: resolution: {integrity: sha512-BNoCY6SXXPQ7gF2opIP4GBE+Xw7U+pHMYKuzjgCN3GwiaIR09UUeKfheyIry77QtrCBlC0KK0q5/TER/tYh3PQ==} engines: {node: '>= 0.4'} - assistant-cloud@0.1.31: - resolution: {integrity: sha512-YBLc79w2EFD/6YjvcZrperpZ+B3TQ9LZ39AbjfcnbIJiSXYAs8cDH+mgy1GrfJBq47nhGaTVEf7ajv+hk084eA==} + assistant-cloud@0.1.22: + resolution: {integrity: sha512-AEE9shV+oFrGDv/MRTRERctNKpIYS0n34UpAQXXICiOkSWD6QZnS1ljLqruFko7fJoT5CIWq8dNeJWdzQLTBLg==} - assistant-stream@0.3.20: - resolution: {integrity: sha512-CniC84epmE9JrMSDzlZVWJ13O5rYbjoqEzh0jT+QfsrR07LBls42DMJ60XNxKXm8Hrn6MHSZcxqBUqwXRtoutA==} - peerDependencies: - ioredis: ^5.10.1 - redis: ^5.12.1 - peerDependenciesMeta: - ioredis: - optional: true - redis: - optional: true + assistant-stream@0.3.6: + resolution: {integrity: sha512-NdtSRrQfWCDA/aqQ1xhobf/xnhuMZkhFAw9xzAt5iAoL3ouxVXOowSRN87OL4MYBQEvqtcjw9/CE6YcsXoBtuw==} ast-types-flow@0.0.8: resolution: {integrity: sha512-OH/2E5Fg20h2aPrbe+QL8JZQFko0YZaF+j4mnQ7BGhfavO7OpSLa8a0y9sBwomHdSbkhTS8TQNayBfnW5DwbvQ==} @@ -8130,9 +8122,6 @@ packages: safe-buffer@5.2.1: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} - safe-content-frame@0.0.20: - resolution: {integrity: sha512-saE3fBeGWOsi04PzTUaRi6RsBIjDYrZX4KzgIZUjbq3xQeOKYMcW1DeTb573Zyx1ggCDVJKoD/THchblISwjiQ==} - safe-push-apply@1.0.0: resolution: {integrity: sha512-iKE9w/Z7xCzUMIZqdBsp6pEQvwuEebH4vdpjcDWnyzaI6yl6O9FHvVpmGelvEHNsoY6wGblkxR6Zty/h00WiSA==} engines: {node: '>= 0.4'} @@ -8905,9 +8894,6 @@ packages: zod@4.3.6: resolution: {integrity: sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==} - zod@4.4.3: - resolution: {integrity: sha512-ytENFjIJFl2UwYglde2jchW2Hwm4GJFLDiSXWdTrJQBIN9Fcyp7n4DhxJEiWNAJMV1/BqWfW/kkg71UDcHJyTQ==} - zustand-x@6.2.1: resolution: {integrity: sha512-y3nQMQNx3BORY95vpuodJvh/8AqQu++S3q6mJYBSo1J0Q168Sy+FatqER658YESDqv2bwviXcIT3bgl/Ip6M5g==} peerDependencies: @@ -8931,24 +8917,6 @@ packages: use-sync-external-store: optional: true - zustand@5.0.14: - resolution: {integrity: sha512-/8tAspM5LMPr28b3fwLYrtdj77ECpfZviaP75CMTnwO8ISyaE4GDIG/9rDDYq/cH9D2Xw2A2RXglLInmVBQB/g==} - engines: {node: '>=12.20.0'} - peerDependencies: - '@types/react': '>=18.0.0' - immer: '>=9.0.6' - react: '>=18.0.0' - use-sync-external-store: '>=1.2.0' - peerDependenciesMeta: - '@types/react': - optional: true - immer: - optional: true - react: - optional: true - use-sync-external-store: - optional: true - zwitch@2.0.4: resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} @@ -9000,24 +8968,21 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@assistant-ui/core@0.2.10(@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.31)(react@19.2.4)(zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))': + '@assistant-ui/core@0.1.7(@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.22)(react@19.2.4)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))': dependencies: - '@assistant-ui/store': 0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) - assistant-stream: 0.3.20 - nanoid: 5.1.11 + '@assistant-ui/store': 0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) + assistant-stream: 0.3.6 + nanoid: 5.1.7 optionalDependencies: '@types/react': 19.2.14 - assistant-cloud: 0.1.31 + assistant-cloud: 0.1.22 react: 19.2.4 - zustand: 5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) - transitivePeerDependencies: - - ioredis - - redis + zustand: 5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) - '@assistant-ui/react-markdown@0.14.1(@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + '@assistant-ui/react-markdown@0.12.6(@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: - '@assistant-ui/react': 0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + '@assistant-ui/react': 0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) '@radix-ui/react-primitive': 2.1.4(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@19.2.14)(react@19.2.4) classnames: 2.5.1 @@ -9030,45 +8995,42 @@ snapshots: - react-dom - supports-color - '@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))': + '@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))': dependencies: - '@assistant-ui/core': 0.2.10(@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.31)(react@19.2.4)(zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))) - '@assistant-ui/store': 0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/core': 0.1.7(@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.22)(react@19.2.4)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))) + '@assistant-ui/store': 0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) '@radix-ui/primitive': 1.1.3 '@radix-ui/react-compose-refs': 1.1.2(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-context': 1.1.3(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-primitive': 2.1.4(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-use-escape-keydown': 1.1.1(@types/react@19.2.14)(react@19.2.4) - assistant-cloud: 0.1.31 - assistant-stream: 0.3.20 - nanoid: 5.1.11 + assistant-cloud: 0.1.22 + assistant-stream: 0.3.6 + nanoid: 5.1.7 radix-ui: 1.4.3(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) react-textarea-autosize: 8.5.9(@types/react@19.2.14)(react@19.2.4) - safe-content-frame: 0.0.20 - zod: 4.4.3 - zustand: 5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + zod: 4.3.6 + zustand: 5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) optionalDependencies: '@types/react': 19.2.14 '@types/react-dom': 19.2.3(@types/react@19.2.14) transitivePeerDependencies: - immer - - ioredis - - redis - use-sync-external-store - '@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4)': + '@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4)': dependencies: - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) react: 19.2.4 use-effect-event: 2.0.3(react@19.2.4) optionalDependencies: '@types/react': 19.2.14 - '@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4)': + '@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4)': optionalDependencies: '@types/react': 19.2.14 react: 19.2.4 @@ -13829,17 +13791,14 @@ snapshots: get-intrinsic: 1.3.0 is-array-buffer: 3.0.5 - assistant-cloud@0.1.31: + assistant-cloud@0.1.22: dependencies: - assistant-stream: 0.3.20 - transitivePeerDependencies: - - ioredis - - redis + assistant-stream: 0.3.6 - assistant-stream@0.3.20: + assistant-stream@0.3.6: dependencies: '@standard-schema/spec': 1.1.0 - nanoid: 5.1.11 + nanoid: 5.1.7 secure-json-parse: 4.1.0 ast-types-flow@0.0.8: {} @@ -17494,8 +17453,6 @@ snapshots: safe-buffer@5.2.1: {} - safe-content-frame@0.0.20: {} - safe-push-apply@1.0.0: dependencies: es-errors: 1.3.0 @@ -18362,8 +18319,6 @@ snapshots: zod@4.3.6: {} - zod@4.4.3: {} - zustand-x@6.2.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))): dependencies: immer: 10.2.0 @@ -18385,11 +18340,4 @@ snapshots: react: 19.2.4 use-sync-external-store: 1.6.0(react@19.2.4) - zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)): - optionalDependencies: - '@types/react': 19.2.14 - immer: 10.2.0 - react: 19.2.4 - use-sync-external-store: 1.6.0(react@19.2.4) - zwitch@2.0.4: {} diff --git a/surfsense_web/proxy.ts b/surfsense_web/proxy.ts deleted file mode 100644 index b53ce68a7..000000000 --- a/surfsense_web/proxy.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { NextResponse, type NextRequest } from "next/server"; -import { BUILD_TIME_AUTH_TYPE } from "@/lib/env-config"; -import { - RUNTIME_AUTH_TYPE_COOKIE_NAME, - resolveRuntimeAuthUiMode, -} from "@/lib/runtime-auth-config"; - -export function proxy(request: NextRequest) { - const response = NextResponse.next(); - const authType = resolveRuntimeAuthUiMode(process.env.AUTH_TYPE, BUILD_TIME_AUTH_TYPE); - - response.cookies.set(RUNTIME_AUTH_TYPE_COOKIE_NAME, authType, { - path: "/", - maxAge: 60 * 60 * 24 * 365, - sameSite: "lax", - secure: request.nextUrl.protocol === "https:", - }); - - return response; -} - -export const config = { - matcher: ["/((?!api|auth|_next/static|_next/image|favicon.ico|.*\\..*).*)"], -}; diff --git a/surfsense_web/zero/queries/index.ts b/surfsense_web/zero/queries/index.ts index 45df8fa98..fe711f5d3 100644 --- a/surfsense_web/zero/queries/index.ts +++ b/surfsense_web/zero/queries/index.ts @@ -4,7 +4,6 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat"; import { connectorQueries, documentQueries } from "./documents"; import { folderQueries } from "./folders"; import { notificationQueries } from "./inbox"; -import { podcastQueries } from "./podcasts"; import { userQueries } from "./user"; export const queries = defineQueries({ @@ -17,5 +16,4 @@ export const queries = defineQueries({ chatSession: chatSessionQueries, user: userQueries, automationRuns: automationRunQueries, - podcasts: podcastQueries, }); diff --git a/surfsense_web/zero/queries/podcasts.ts b/surfsense_web/zero/queries/podcasts.ts deleted file mode 100644 index 5298534dd..000000000 --- a/surfsense_web/zero/queries/podcasts.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { defineQuery } from "@rocicorp/zero"; -import { z } from "zod"; -import { zql } from "../schema/index"; - -export const podcastQueries = { - bySpace: defineQuery(z.object({ searchSpaceId: z.number() }), ({ args: { searchSpaceId } }) => - zql.podcasts.where("searchSpaceId", searchSpaceId).orderBy("createdAt", "desc") - ), - byId: defineQuery(z.object({ podcastId: z.number() }), ({ args: { podcastId } }) => - zql.podcasts.where("id", podcastId).one() - ), -}; diff --git a/surfsense_web/zero/schema/index.ts b/surfsense_web/zero/schema/index.ts index d1187ddab..d6731e371 100644 --- a/surfsense_web/zero/schema/index.ts +++ b/surfsense_web/zero/schema/index.ts @@ -4,7 +4,6 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./ import { documentTable, searchSourceConnectorTable } from "./documents"; import { folderTable } from "./folders"; import { notificationTable } from "./inbox"; -import { podcastTable } from "./podcasts"; import { userTable } from "./user"; const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({ @@ -39,7 +38,6 @@ export const schema = createSchema({ chatSessionStateTable, userTable, automationRunTable, - podcastTable, ], relationships: [chatCommentRelationships, newChatMessageRelationships], }); diff --git a/surfsense_web/zero/schema/podcasts.ts b/surfsense_web/zero/schema/podcasts.ts deleted file mode 100644 index d473d776f..000000000 --- a/surfsense_web/zero/schema/podcasts.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { json, number, string, table } from "@rocicorp/zero"; - -// Mirrors PODCAST_COLS in the backend zero_publication. status drives the -// lifecycle UI by push; spec is the reviewable brief. The bulky source_content -// and transcript are intentionally not published and are fetched over REST. -export const podcastTable = table("podcasts") - .columns({ - id: number(), - title: string(), - status: string(), - spec: json().optional(), - specVersion: number().from("spec_version"), - durationSeconds: number().optional().from("duration_seconds"), - error: string().optional(), - searchSpaceId: number().from("search_space_id"), - threadId: number().optional().from("thread_id"), - createdAt: number().from("created_at"), - }) - .primaryKey("id"); diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts index 3b6c3ec92..f483fa9b4 100644 --- a/surfsense_web/zero/schema/user.ts +++ b/surfsense_web/zero/schema/user.ts @@ -3,16 +3,18 @@ import { number, string, table } from "@rocicorp/zero"; /** * Live-meter slice of the ``user`` table replicated through Zero. * - * ``creditMicrosBalance`` is stored as integer micro-USD (1_000_000 == $1.00); - * UI consumers divide by 1M when displaying and clamp at $0.00 (the balance can - * dip slightly negative when actual cost exceeds the pre-charge estimate). - * Sensitive fields (email, hashed_password, oauth, etc.) are intentionally - * omitted via the Postgres column-list publication so they never enter WAL - * replication. + * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored + * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M + * when displaying. Sensitive fields (email, hashed_password, oauth, etc.) + * are intentionally omitted via the Postgres column-list publication so + * they never enter WAL replication. */ export const userTable = table("user") .columns({ id: string(), - creditMicrosBalance: number().from("credit_micros_balance"), + pagesLimit: number().from("pages_limit"), + pagesUsed: number().from("pages_used"), + premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"), + premiumCreditMicrosUsed: number().from("premium_credit_micros_used"), }) .primaryKey("id");