diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index 7336fa9bd..4fcc8c597 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -95,12 +95,10 @@ jobs: run: pnpm build working-directory: surfsense_web env: - NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }} - SURFSENSE_BACKEND_INTERNAL_URL: ${{ vars.HOSTED_BACKEND_URL }} + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }} NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }} NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }} - NEXT_PUBLIC_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_AUTH_TYPE }} - NEXT_PUBLIC_ETL_SERVICE: ${{ vars.NEXT_PUBLIC_ETL_SERVICE }} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }} NEXT_PUBLIC_POSTHOG_KEY: ${{ secrets.NEXT_PUBLIC_POSTHOG_KEY }} - name: Install desktop dependencies @@ -111,7 +109,6 @@ jobs: run: pnpm build working-directory: surfsense_desktop env: - HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }} HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }} POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }} POSTHOG_HOST: ${{ vars.POSTHOG_HOST }} diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index f82c4d609..8da56fc33 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -199,6 +199,11 @@ jobs: build-args: | ${{ matrix.image == 'backend' && format('USE_CUDA={0}', matrix.use_cuda) || '' }} ${{ matrix.image == 'backend' && format('CUDA_EXTRA={0}', matrix.cuda_extra) || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }} + ${{ matrix.image == 'web' && 
'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }} + ${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }} - name: Export digest run: | diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index bd9cff13b..b87537dab 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -27,10 +27,9 @@ jobs: PLAYWRIGHT_TEST_EMAIL: e2e-test@surfsense.net PLAYWRIGHT_TEST_PASSWORD: E2eTestPassword123! # Frontend env: Playwright's webServer (surfsense_web/playwright.config.ts) - # spawns `pnpm build && pnpm start` in CI. + # spawns `pnpm build && pnpm start` in CI; these get baked into the build. NEXT_PUBLIC_FASTAPI_BACKEND_URL: http://localhost:8000 - SURFSENSE_BACKEND_INTERNAL_URL: http://localhost:8000 - AUTH_TYPE: LOCAL + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: LOCAL # Shared secret for the test-only POST /__e2e__/auth/token endpoint. # Must match docker-compose.e2e.yml's backend env (x-backend-env). E2E_MINT_SECRET: e2e-mint-secret-not-for-production diff --git a/VERSION b/VERSION index 369bd4c2a..1fe695856 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.29 +0.0.28 diff --git a/docker/.env.example b/docker/.env.example index 63308bc9e..54ca489b2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -30,9 +30,6 @@ SECRET_KEY=replace_me_with_a_random_string # Auth type: LOCAL (email/password) or GOOGLE (OAuth) AUTH_TYPE=LOCAL -# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them. 
-DEPLOYMENT_MODE=self-hosted - # Allow new user registrations (TRUE or FALSE) # REGISTRATION_ENABLED=TRUE @@ -46,47 +43,51 @@ ETL_SERVICE=DOCLING EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 # ------------------------------------------------------------------------------ -# How You Access SurfSense +# Ports (change to avoid conflicts with other services on your machine) # ------------------------------------------------------------------------------ -# One public URL. Browser traffic stays same-origin and Caddy routes internally. -SURFSENSE_PUBLIC_URL=http://localhost:3929 + +# BACKEND_PORT=8929 +# FRONTEND_PORT=3929 +# ZERO_CACHE_PORT=5929 +# SEARXNG_PORT=8888 +# FLOWER_PORT=5555 + +# ============================================================================== +# DEV COMPOSE ONLY (docker-compose.dev.yml) +# You need these only if you are running `docker-compose.dev.yml`. +# ============================================================================== + +# -- pgAdmin (database GUI) -- +# PGADMIN_PORT=5050 +# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com +# PGADMIN_DEFAULT_PASSWORD=surfsense + +# -- Redis exposed port (dev only; Redis is internal-only in prod) -- +# REDIS_PORT=6379 + +# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) -- +# WHATSAPP_BRIDGE_PORT=9929 + +# -- Frontend Build Args -- +# In dev, the frontend is built from source and these are passed as build args. +# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above. +# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL +# NEXT_PUBLIC_ETL_SERVICE=DOCLING +# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted # ------------------------------------------------------------------------------ -# Public Ports +# Custom Domain / Reverse Proxy # ------------------------------------------------------------------------------ -# Production Docker exposes only Caddy to your machine. 
Caddy then routes -# frontend, backend, and zero-cache traffic internally. +# ONLY set these if you are serving SurfSense on a real domain via a reverse +# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel). +# For standard localhost deployments, leave all of these commented out. +# They are automatically derived from the port settings above. # -# Local default: LISTEN_HTTP_PORT=3929 -# Domain default: LISTEN_HTTP_PORT=80 and LISTEN_HTTPS_PORT=443 -LISTEN_HTTP_PORT=3929 -LISTEN_HTTPS_PORT=443 - -# ------------------------------------------------------------------------------ -# Custom Domain / HTTPS -# ------------------------------------------------------------------------------ -# Leave SURFSENSE_SITE_ADDRESS as :80 for local HTTP. -# Set it to your domain to enable automatic HTTPS: -# SURFSENSE_SITE_ADDRESS=surf.example.com -# CERT_EMAIL=you@example.com -SURFSENSE_SITE_ADDRESS=:80 -CERT_EMAIL= - -# ------------------------------------------------------------------------------ -# Advanced Reverse Proxy Settings -# ------------------------------------------------------------------------------ -# Usually do not change these. They are for custom certificate setup, CDNs/load -# balancers, trusted proxy IPs, or changing upload limits. -# -# CERT_ACME_CA=https://acme-v02.api.letsencrypt.org/directory -# CERT_ACME_DNS= -# If a CDN/load balancer sits in front of Caddy, narrow this to that proxy's CIDRs. -# TRUSTED_PROXIES=0.0.0.0/0 -# SURFSENSE_MAX_BODY_SIZE=5GB -# -# Browser API and Zero URLs are same-origin relative behind bundled Caddy. -# Next.js server-side calls use Docker DNS through SURFSENSE_BACKEND_INTERNAL_URL -# set internally by docker-compose.yml. Usually do not override it. 
+# NEXT_FRONTEND_URL=https://app.yourdomain.com +# BACKEND_URL=https://api.yourdomain.com +# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com +# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com +# FASTAPI_BACKEND_INTERNAL_URL=http://backend:8000 # ------------------------------------------------------------------------------ # Zero-cache (real-time sync) @@ -107,9 +108,10 @@ CERT_EMAIL= # Sync worker tuning. zero-cache defaults ZERO_NUM_SYNC_WORKERS to the number # of CPU cores, which can exceed the connection pool limits on high-core machines. -# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR pools. -# Keep ZERO_UPSTREAM_MAX_CONNS and ZERO_CVR_MAX_CONNS greater than or equal to -# ZERO_NUM_SYNC_WORKERS. +# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR +# pools, so these constraints must hold: +# ZERO_UPSTREAM_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS +# ZERO_CVR_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS # Default of 4 workers is sufficient for self-hosted / personal use. # ZERO_NUM_SYNC_WORKERS=4 # ZERO_UPSTREAM_MAX_CONNS=20 @@ -123,16 +125,16 @@ CERT_EMAIL= # ZERO_QUERY_URL: where zero-cache forwards query requests for resolution. # ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though -# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to -# skip its own JWT verification and let the app endpoints handle auth instead. -# The mutate endpoint is a no-op that returns an empty response. +# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to +# skip its own JWT verification and let the app endpoints handle auth instead. +# The mutate endpoint is a no-op that returns an empty response. # Default: Docker service networking (http://frontend:3000/api/zero/...). 
# Override when running the frontend outside Docker: -# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query -# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate -# Override for custom domain only when zero-cache is not in the bundled Docker network: -# ZERO_QUERY_URL=https://surf.example.com/api/zero/query -# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate +# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query +# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate +# Override for custom domain: +# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query +# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate # ZERO_QUERY_URL=http://frontend:3000/api/zero/query # ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate @@ -220,74 +222,73 @@ STT_SERVICE=local/base # ------------------------------------------------------------------------------ # -- Google Connectors -- -# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/calendar/connector/callback -# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/gmail/connector/callback -# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/drive/connector/callback +# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback +# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback +# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback # -- Notion -- # NOTION_CLIENT_ID= # NOTION_CLIENT_SECRET= -# NOTION_REDIRECT_URI=http://localhost:3929/api/v1/auth/notion/connector/callback +# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback # -- Slack -- # SLACK_CLIENT_ID= # SLACK_CLIENT_SECRET= -# SLACK_REDIRECT_URI=http://localhost:3929/api/v1/auth/slack/connector/callback +# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback # -- Discord -- # 
DISCORD_CLIENT_ID= # DISCORD_CLIENT_SECRET= -# DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/auth/discord/connector/callback +# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback # DISCORD_BOT_TOKEN= # -- Atlassian (Jira & Confluence) -- # ATLASSIAN_CLIENT_ID= # ATLASSIAN_CLIENT_SECRET= -# JIRA_REDIRECT_URI=http://localhost:3929/api/v1/auth/jira/connector/callback -# CONFLUENCE_REDIRECT_URI=http://localhost:3929/api/v1/auth/confluence/connector/callback +# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback +# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback # -- Linear -- # LINEAR_CLIENT_ID= # LINEAR_CLIENT_SECRET= -# LINEAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/linear/connector/callback +# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback # -- ClickUp -- # CLICKUP_CLIENT_ID= # CLICKUP_CLIENT_SECRET= -# CLICKUP_REDIRECT_URI=http://localhost:3929/api/v1/auth/clickup/connector/callback +# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback # -- Airtable -- # AIRTABLE_CLIENT_ID= # AIRTABLE_CLIENT_SECRET= -# AIRTABLE_REDIRECT_URI=http://localhost:3929/api/v1/auth/airtable/connector/callback +# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback # -- Microsoft OAuth (Teams & OneDrive) -- # MICROSOFT_CLIENT_ID= # MICROSOFT_CLIENT_SECRET= -# TEAMS_REDIRECT_URI=http://localhost:3929/api/v1/auth/teams/connector/callback -# ONEDRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/onedrive/connector/callback +# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback +# ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback # -- Dropbox -- # DROPBOX_APP_KEY= # DROPBOX_APP_SECRET= -# DROPBOX_REDIRECT_URI=http://localhost:3929/api/v1/auth/dropbox/connector/callback +# 
DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback # -- Composio -- # COMPOSIO_API_KEY= # COMPOSIO_ENABLED=TRUE -# COMPOSIO_REDIRECT_URI=http://localhost:3929/api/v1/auth/composio/connector/callback +# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback # ------------------------------------------------------------------------------ # Messaging Channels (optional) # ------------------------------------------------------------------------------ # Configure only the external chat channels you want to use. -# GATEWAY_ENABLED=TRUE # -- Telegram -- # TELEGRAM_SHARED_BOT_TOKEN= # TELEGRAM_SHARED_BOT_USERNAME= # TELEGRAM_WEBHOOK_SECRET= -# GATEWAY_BASE_URL=http://localhost:3929 +# GATEWAY_BASE_URL=http://localhost:8929 # GATEWAY_TELEGRAM_INTAKE_MODE=webhook # -- WhatsApp -- @@ -306,20 +307,20 @@ STT_SERVICE=local/base # # GATEWAY_SLACK_ENABLED=FALSE # GATEWAY_SLACK_SIGNING_SECRET= -# GATEWAY_SLACK_REDIRECT_URI=http://localhost:3929/api/v1/gateway/slack/callback +# GATEWAY_SLACK_REDIRECT_URI=http://localhost:8929/api/v1/gateway/slack/callback # -- Discord -- # Uses DISCORD_CLIENT_ID, DISCORD_CLIENT_SECRET, and DISCORD_BOT_TOKEN from the # Discord connector section. # # GATEWAY_DISCORD_ENABLED=FALSE -# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/gateway/discord/callback +# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:8929/api/v1/gateway/discord/callback # ------------------------------------------------------------------------------ # SearXNG (bundled web search, works out of the box with no config needed) # ------------------------------------------------------------------------------ # SearXNG provides web search to all search spaces automatically. 
-# To access the SearXNG UI directly in dev/deps-only compose: http://localhost:8888 +# To access the SearXNG UI directly: http://localhost:8888 # To disable the service entirely: docker compose up --scale searxng=0 # To point at your own SearXNG instance instead of the bundled one: # SEARXNG_DEFAULT_HOST=http://your-searxng:8080 @@ -456,36 +457,3 @@ NOLOGIN_MODE_ENABLED=FALSE # RESIDENTIAL_PROXY_HOSTNAME= # RESIDENTIAL_PROXY_LOCATION= # RESIDENTIAL_PROXY_TYPE=1 - -# ============================================================================== -# DEV / DEPS-ONLY COMPOSE OVERRIDES -# These are only needed for docker-compose.dev.yml or docker-compose.deps-only.yml. -# Production Docker exposes Caddy only; raw app ports below do not affect -# docker-compose.yml. -# ============================================================================== - -# -- pgAdmin (database GUI, dev/deps-only only) -- -# PGADMIN_PORT=5050 -# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com -# PGADMIN_DEFAULT_PASSWORD=surfsense - -# -- Redis exposed port (dev/deps-only only; Redis is internal-only in prod) -- -# REDIS_PORT=6379 - -# -- SearXNG exposed port (dev/deps-only only; internal-only in prod) -- -# SEARXNG_PORT=8888 - -# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) -- -# WHATSAPP_BRIDGE_PORT=9929 - -# -- Raw app ports (dev/deps-only only; prod exposes Caddy instead) -- -# BACKEND_PORT=8000 -# FRONTEND_PORT=3000 -# ZERO_CACHE_PORT=4848 - -# -- Frontend runtime flags (prod and dev compose) -- -# The frontend reads these at request time in Docker; no NEXT_PUBLIC_* rebuild -# or startup substitution is required. 
-# AUTH_TYPE=LOCAL -# ETL_SERVICE=DOCLING -# DEPLOYMENT_MODE=self-hosted diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 5b86ea888..35effefc0 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -106,7 +106,6 @@ services: volumes: - ../surfsense_backend/app:/app/app - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - ../surfsense_backend/.env extra_hosts: @@ -120,7 +119,6 @@ services: - PYTHONPATH=/app - UVICORN_LOOP=asyncio - UNSTRUCTURED_HAS_PATCHED_LOOP=1 - - FILE_STORAGE_LOCAL_PATH=/app/.local_object_store - LANGCHAIN_TRACING_V2=false - LANGSMITH_TRACING=false - AUTH_TYPE=${AUTH_TYPE:-LOCAL} @@ -173,7 +171,6 @@ services: volumes: - ../surfsense_backend/app:/app/app - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - ../surfsense_backend/.env extra_hosts: @@ -185,7 +182,6 @@ services: - REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0} - CELERY_TASK_DEFAULT_QUEUE=surfsense - PYTHONPATH=/app - - FILE_STORAGE_LOCAL_PATH=/app/.local_object_store - SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080} - SERVICE_ROLE=worker depends_on: @@ -257,15 +253,16 @@ services: frontend: build: context: ../surfsense_web + args: + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL} + NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING} + NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}} + NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted} ports: - "${FRONTEND_PORT:-3000}:3000" env_file: - ../surfsense_web/.env - environment: - AUTH_TYPE: ${AUTH_TYPE:-LOCAL} - ETL_SERVICE: ${ETL_SERVICE:-DOCLING} - DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} - SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000 depends_on: backend: condition: service_healthy 
@@ -281,8 +278,6 @@ volumes: name: surfsense-dev-redis shared_temp: name: surfsense-dev-shared-temp - object_store: - name: surfsense-dev-object-store zero_cache_data: name: surfsense-dev-zero-cache whatsapp_sessions: diff --git a/docker/docker-compose.proxy.yml b/docker/docker-compose.proxy.yml deleted file mode 100644 index 1990f6db8..000000000 --- a/docker/docker-compose.proxy.yml +++ /dev/null @@ -1,54 +0,0 @@ -# ============================================================================= -# SurfSense — Optional Caddy reverse-proxy overlay -# ============================================================================= -# Usage (from docker/): -# PROXY_HTTP_PORT=8080 SURFSENSE_PUBLIC_URL=http://localhost:8080 \ -# docker compose -f docker-compose.yml -f docker-compose.proxy.yml up -d -# -# This overlay is for validation and custom deployments. The production -# docker-compose.yml includes Caddy by default. -# ============================================================================= - -services: - backend: - ports: - - "${BACKEND_PORT:-8929}:8000" - - zero-cache: - ports: - - "${ZERO_CACHE_PORT:-5929}:4848" - - frontend: - ports: - - "${FRONTEND_PORT:-3929}:3000" - - proxy: - image: caddy:2-alpine - restart: unless-stopped - ports: - - "${PROXY_HTTP_PORT:-8080}:80" - - "${PROXY_HTTPS_PORT:-8443}:443" - volumes: - - ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro - - caddy_data:/data - - caddy_config:/config - environment: - SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80} - CERT_EMAIL: ${CERT_EMAIL:-} - CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory} - CERT_ACME_DNS: ${CERT_ACME_DNS:-} - TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0} - SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB} - depends_on: - frontend: - condition: service_started - backend: - condition: service_healthy - zero-cache: - condition: service_healthy - -volumes: - caddy_data: - name: surfsense-caddy-data - caddy_config: - name: surfsense-caddy-config 
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1ee7ae0ed..9bbf28ffd 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -94,42 +94,12 @@ services: timeout: 5s retries: 5 - # Single public entry point for the Docker stack. Comment this service out - # only if you front SurfSense with your own reverse proxy. - proxy: - image: caddy:2-alpine - # For DNS-01/wildcard certificates, replace image with: - # build: ./proxy - restart: unless-stopped - ports: - - "${LISTEN_HTTP_PORT:-3929}:80" - - "${LISTEN_HTTPS_PORT:-443}:443" - volumes: - - ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro - - caddy_data:/data - - caddy_config:/config - environment: - SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80} - CERT_EMAIL: ${CERT_EMAIL:-} - CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory} - CERT_ACME_DNS: ${CERT_ACME_DNS:-} - TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0} - SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB} - depends_on: - frontend: - condition: service_started - backend: - condition: service_healthy - zero-cache: - condition: service_healthy - backend: image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}} - expose: - - "8000" + ports: + - "${BACKEND_PORT:-8929}:8000" volumes: - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - .env extra_hosts: @@ -143,9 +113,7 @@ services: PYTHONPATH: /app UVICORN_LOOP: asyncio UNSTRUCTURED_HAS_PATCHED_LOOP: "1" - FILE_STORAGE_LOCAL_PATH: /app/.local_object_store - NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}} - BACKEND_URL: ${BACKEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}} + NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}} SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} WHATSAPP_BRIDGE_URL: 
${WHATSAPP_BRIDGE_URL:-http://whatsapp-bridge:9929} # Daytona Sandbox – uncomment and set credentials to enable cloud code execution @@ -197,7 +165,6 @@ services: image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}} volumes: - shared_temp:/shared_tmp - - object_store:/app/.local_object_store env_file: - .env extra_hosts: @@ -209,7 +176,6 @@ services: REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0} CELERY_TASK_DEFAULT_QUEUE: surfsense PYTHONPATH: /app - FILE_STORAGE_LOCAL_PATH: /app/.local_object_store SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080} SERVICE_ROLE: worker depends_on: @@ -251,8 +217,8 @@ services: zero-cache: image: rocicorp/zero:1.4.0 - expose: - - "4848" + ports: + - "${ZERO_CACHE_PORT:-5929}:4848" extra_hosts: - "host.docker.internal:host-gateway" environment: @@ -286,13 +252,16 @@ services: frontend: image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest} - expose: - - "3000" + ports: + - "${FRONTEND_PORT:-3929}:3000" environment: - AUTH_TYPE: ${AUTH_TYPE:-LOCAL} - ETL_SERVICE: ${ETL_SERVICE:-DOCLING} - DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} - SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000 + NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}} + NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}} + NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL} + NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING} + NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted} + NEXT_PUBLIC_WHATSAPP_DISPLAY_PHONE_NUMBER: ${WHATSAPP_SHARED_DISPLAY_PHONE_NUMBER:-} + FASTAPI_BACKEND_INTERNAL_URL: ${FASTAPI_BACKEND_INTERNAL_URL:-http://backend:8000} labels: - "com.centurylinklabs.watchtower.enable=true" depends_on: @@ -309,13 +278,7 @@ volumes: name: surfsense-redis shared_temp: name: surfsense-shared-temp - object_store: - name: 
surfsense-object-store zero_cache_data: name: surfsense-zero-cache - caddy_data: - name: surfsense-caddy-data - caddy_config: - name: surfsense-caddy-config whatsapp_sessions: name: surfsense-whatsapp-sessions diff --git a/docker/proxy/Caddyfile b/docker/proxy/Caddyfile deleted file mode 100644 index 534a8c2c2..000000000 --- a/docker/proxy/Caddyfile +++ /dev/null @@ -1,45 +0,0 @@ -{ - # Optional ACME/global settings. These are harmless in the default :80 - # localhost mode and become active when SURFSENSE_SITE_ADDRESS is a domain. - {$CERT_EMAIL} - acme_ca {$CERT_ACME_CA:https://acme-v02.api.letsencrypt.org/directory} - {$CERT_ACME_DNS} - servers { - client_ip_headers X-Forwarded-For X-Real-IP - trusted_proxies static {$TRUSTED_PROXIES:0.0.0.0/0} - } -} - -(surfsense_proxy) { - request_body { - max_size {$SURFSENSE_MAX_BODY_SIZE:5GB} - } - - # Frontend-owned auth page (the post-login token handler). More specific than - # /auth/*, so Caddy's matcher-specificity sort routes it here, not to backend. - reverse_proxy /auth/callback* frontend:3000 - - # Backend auth routes (FastAPI Users + OAuth helpers). - reverse_proxy /auth/* backend:8000 - - # Backend user profile routes (FastAPI Users users router, mounted at /users). - reverse_proxy /users/* backend:8000 - - # Backend REST, streaming, connector OAuth, and messaging gateway endpoints. - # FastAPI already serves /api/v1, so the path is forwarded unchanged. - reverse_proxy /api/v1/* backend:8000 { - flush_interval -1 - } - - # Zero accepts a single path-component base URL (Zero >= 0.6). - # Preserve /zero so browser cacheURL can be ${SURFSENSE_PUBLIC_URL}/zero. - reverse_proxy /zero/* zero-cache:4848 - - # Next.js app and frontend-owned API routes: - # /api/zero/*, /api/search, /api/contact, etc. 
- reverse_proxy /* frontend:3000 -} - -{$SURFSENSE_SITE_ADDRESS::80} { - import surfsense_proxy -} diff --git a/docker/proxy/Dockerfile b/docker/proxy/Dockerfile deleted file mode 100644 index 8395a817c..000000000 --- a/docker/proxy/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM caddy:2-builder-alpine AS builder - -RUN xcaddy build \ - --with github.com/caddy-dns/cloudflare \ - --with github.com/caddy-dns/digitalocean - -FROM caddy:2-alpine - -COPY --from=builder /usr/bin/caddy /usr/bin/caddy -COPY Caddyfile /etc/caddy/Caddyfile diff --git a/docker/scripts/install.sh b/docker/scripts/install.sh index f9660b132..4df15fbd0 100644 --- a/docker/scripts/install.sh +++ b/docker/scripts/install.sh @@ -333,13 +333,11 @@ step "Downloading SurfSense files" info "Installation directory: ${INSTALL_DIR}" mkdir -p "${INSTALL_DIR}/scripts" mkdir -p "${INSTALL_DIR}/searxng" -mkdir -p "${INSTALL_DIR}/proxy" FILES=( "docker/docker-compose.yml:docker-compose.yml" "docker/docker-compose.gpu.yml:docker-compose.gpu.yml" "docker/.env.example:.env.example" - "docker/proxy/Caddyfile:proxy/Caddyfile" "docker/postgresql.conf:postgresql.conf" "docker/scripts/migrate-database.sh:scripts/migrate-database.sh" "docker/searxng/settings.yml:searxng/settings.yml" @@ -534,12 +532,9 @@ _variant_display=$(grep '^SURFSENSE_VARIANT=' "${INSTALL_DIR}/.env" 2>/dev/null _variant_display="${_variant_display:-cpu}" step "SurfSense is now installed [${_version_display}]" -_public_url=$(grep '^SURFSENSE_PUBLIC_URL=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2- | tr -d '"' | head -1 || true) -_public_url="${_public_url:-http://localhost:3929}" - -info " SurfSense: ${_public_url}" -info " Backend: ${_public_url}/api/v1" -info " Zero sync: ${_public_url}/zero" +info " Frontend: http://localhost:3929" +info " Backend: http://localhost:8929" +info " API Docs: http://localhost:8929/docs" info "" info " Config: ${INSTALL_DIR}/.env" info " Variant: ${_variant_display}" diff --git a/surfsense_backend/.env.example 
b/surfsense_backend/.env.example index a6b2b30a3..b4f67328c 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -1,20 +1,5 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense -# --- Database startup / safety knobs (optional) --- -# Run extension/table/index DDL on app startup. Set FALSE when schema is owned -# exclusively by Alembic migrations. -# DB_BOOTSTRAP_ON_STARTUP=TRUE -# lock_timeout (ms) for boot-time DDL so a contended CREATE INDEX/TABLE fails -# fast instead of hanging the FastAPI lifespan behind another transaction. -# DB_DDL_LOCK_TIMEOUT_MS=5000 -# idle_in_transaction_session_timeout (ms) so an abandoned "idle in transaction" -# session can't wedge the DB indefinitely. 0 disables. (asyncpg only) -# DB_IDLE_IN_TX_TIMEOUT_MS=900000 -# Same, for the Celery worker engine (long ingestion/podcast/video tasks). If a -# task hasn't touched the DB in this window it's treated as orphaned and dropped. -# 0 disables. (asyncpg only) -# DB_CELERY_IDLE_IN_TX_TIMEOUT_MS=3600000 - # Deployment environment: dev or production SURFSENSE_ENV=dev @@ -30,9 +15,12 @@ CELERY_TASK_DEFAULT_QUEUE=surfsense # Optional: TTL in seconds for connector indexing lock key # CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800 -# Messaging Gateway: disabled by default; set TRUE to enable chat integrations. -# Supported messaging gateways: WhatsApp, Telegram, Discord, Slack -# GATEWAY_ENABLED=TRUE +# Messaging Gateway (global) +# GATEWAY_ENABLED: master switch for ALL messaging gateway channels (Telegram, WhatsApp, +# Slack, Discord). When FALSE, no gateway background workers/supervisors start and all +# gateway HTTP routes (webhooks, OAuth callbacks, pairing) return 404. Set per-channel +# flags below to control individual platforms once the gateway is enabled. 
+GATEWAY_ENABLED=TRUE # Telegram Gateway # TELEGRAM_WEBHOOK_SECRET must be 1-256 chars and contain only A-Z, a-z, 0-9, _ or - @@ -323,42 +311,6 @@ FILE_STORAGE_BACKEND=local # AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net # AZURE_STORAGE_CONTAINER=surfsense-documents -# ETL Parse Cache -# Reuse parser output for identical file bytes across workspaces (skips paid -# re-parsing on LlamaCloud / Azure DI / Unstructured). Off by default. -ETL_CACHE_ENABLED=false -# Bump to invalidate all cached entries after a parser/behaviour change. -# ETL_CACHE_PARSER_VERSION=1 -# Prune entries unused for this many days. -# ETL_CACHE_TTL_DAYS=90 -# Soft cap on total cached markdown; coldest entries are evicted past it. -# ETL_CACHE_MAX_TOTAL_MB=5120 -# Rows deleted per eviction pass. -# ETL_CACHE_EVICTION_BATCH=500 -# Optional dedicated blob storage; unset reuses the main file storage backend. -# ETL_CACHE_STORAGE_BACKEND=azure -# ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache -# ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache - -# Embedding Cache -# Reuse chunk+embedding output for identical markdown across workspaces (skips -# re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend. -# Off by default. -EMBEDDING_CACHE_ENABLED=false -# Bump to invalidate all cached embedding sets after a chunker change. -# EMBEDDING_CACHE_CHUNKER_VERSION=1 -# Prune entries unused for this many days. -# EMBEDDING_CACHE_TTL_DAYS=90 -# Soft cap on total cached embeddings; coldest entries are evicted past it. -# EMBEDDING_CACHE_MAX_TOTAL_MB=5120 -# Rows deleted per eviction pass. -# EMBEDDING_CACHE_EVICTION_BATCH=500 - -# Incremental re-indexing: on document edits, keep chunks whose text is -# unchanged (reusing their embeddings) and embed only new/changed ones. -# Set to false to fall back to delete-all + full re-embed (kill switch). 
-# CHUNK_RECONCILE_ENABLED=true - # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_API_KEY=your-daytona-api-key @@ -398,9 +350,7 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call # Observability - OTel -# Disabled by default. Uncomment to enable OpenTelemetry. -# SURFSENSE_ENABLE_OTEL=true - +# SURFSENSE_ENABLE_OTEL=false # OpenTelemetry - endpoint enables export; absent = no-op. # Production should point at an OTel Collector. For local docker-compose.dev.yml, # use http://otel-lgtm:4317 instead. diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 8c74b637b..fba621a0c 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,7 +4,7 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add a single thread-level column to persist the Auto model pin: +Add a single thread-level column to persist the Auto (Fastest) model pin: - pinned_llm_config_id: concrete resolved global LLM config id used for this thread. NULL means "no pin; Auto will resolve on next turn". 
diff --git a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py index f1d231f9e..f3b194cbd 100644 --- a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py +++ b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py @@ -15,19 +15,6 @@ down_revision: str | None = "157" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None -PUBLICATION_NAME = "zero_publication" -TARGET_STATUS_LABELS = ( - "pending", - "awaiting_brief", - "drafting", - "awaiting_review", - "rendering", - "ready", - "failed", - "cancelled", -) -LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed") - def _drop_podcasts_from_publication() -> None: """Detach podcasts from zero_publication so status can be retyped. @@ -41,103 +28,31 @@ def _drop_podcasts_from_publication() -> None: published = conn.execute( sa.text( "SELECT 1 FROM pg_publication_tables " - "WHERE pubname = :publication " + "WHERE pubname = 'zero_publication' " "AND schemaname = current_schema() AND tablename = 'podcasts'" - ), - {"publication": PUBLICATION_NAME}, + ) ).fetchone() if published: - op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";') + op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";') -def _enum_labels(type_name: str) -> list[str] | None: - rows = ( - op.get_bind() - .execute( - sa.text( - "SELECT e.enumlabel " - "FROM pg_type t " - "JOIN pg_namespace n ON n.oid = t.typnamespace " - "JOIN pg_enum e ON e.enumtypid = t.oid " - "WHERE n.nspname = current_schema() AND t.typname = :type_name " - "ORDER BY e.enumsortorder" - ), - {"type_name": type_name}, - ) - .fetchall() - ) - if not rows: - return None - return [str(row[0]) for row in rows] +def upgrade() -> None: + _drop_podcasts_from_publication() - -def _column_type_name(table: str, column: str) -> str | None: - row = ( - op.get_bind() - .execute( - sa.text( - 
"SELECT udt_name " - "FROM information_schema.columns " - "WHERE table_schema = current_schema() " - "AND table_name = :table AND column_name = :column" - ), - {"table": table, "column": column}, - ) - .fetchone() - ) - return str(row[0]) if row else None - - -def _ensure_status_enum( - *, - desired_labels: tuple[str, ...], - temporary_type: str, - create_sql: str, - alter_sql: str, - default_value: str, -) -> None: - current_labels = _enum_labels("podcast_status") - desired = list(desired_labels) - - if current_labels != desired: - if current_labels is None: - if _enum_labels(temporary_type) is None: - raise RuntimeError("podcast_status enum is missing") - elif _enum_labels(temporary_type) is None: - op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};") - else: - raise RuntimeError( - "podcast_status and its temporary replacement both exist" - ) - - if _enum_labels("podcast_status") is None: - op.execute(create_sql) - - if _enum_labels("podcast_status") != desired: - raise RuntimeError("podcast_status enum is not in the expected shape") - - op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") - if _column_type_name("podcasts", "status") != "podcast_status": - op.execute(alter_sql) + # Retype the status enum by swapping in a fresh type and casting existing + # rows. The legacy transient value 'generating' maps onto 'rendering'. 
+ op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;") op.execute( - f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';" - ) - - if _enum_labels(temporary_type) is not None: - op.execute(f"DROP TYPE {temporary_type};") - - -def _upgrade_status_enum() -> None: - _ensure_status_enum( - desired_labels=TARGET_STATUS_LABELS, - temporary_type="podcast_status_old", - create_sql=""" + """ CREATE TYPE podcast_status AS ENUM ( 'pending', 'awaiting_brief', 'drafting', 'awaiting_review', 'rendering', 'ready', 'failed', 'cancelled' ); - """, - alter_sql=""" + """ + ) + op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") + op.execute( + """ ALTER TABLE podcasts ALTER COLUMN status TYPE podcast_status USING ( @@ -146,43 +61,10 @@ def _upgrade_status_enum() -> None: ELSE status::text END )::podcast_status; - """, - default_value="pending", + """ ) - - -def _downgrade_status_enum() -> None: - _ensure_status_enum( - desired_labels=LEGACY_STATUS_LABELS, - temporary_type="podcast_status_new", - create_sql=( - "CREATE TYPE podcast_status AS ENUM " - "('pending', 'generating', 'ready', 'failed');" - ), - alter_sql=""" - ALTER TABLE podcasts - ALTER COLUMN status TYPE podcast_status - USING ( - CASE status::text - WHEN 'awaiting_brief' THEN 'pending' - WHEN 'drafting' THEN 'generating' - WHEN 'awaiting_review' THEN 'generating' - WHEN 'rendering' THEN 'generating' - WHEN 'cancelled' THEN 'failed' - ELSE status::text - END - )::podcast_status; - """, - default_value="ready", - ) - - -def upgrade() -> None: - _drop_podcasts_from_publication() - - # Retype the status enum by swapping in a fresh type and casting existing - # rows. The legacy transient value 'generating' maps onto 'rendering'. 
- _upgrade_status_enum() + op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';") + op.execute("DROP TYPE podcast_status_old;") op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;") op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;") @@ -201,8 +83,6 @@ def upgrade() -> None: def downgrade() -> None: - _drop_podcasts_from_publication() - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;") @@ -212,4 +92,27 @@ def downgrade() -> None: op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;") # Collapse the expanded lifecycle back onto the original four values. - _downgrade_status_enum() + op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;") + op.execute( + "CREATE TYPE podcast_status AS ENUM " + "('pending', 'generating', 'ready', 'failed');" + ) + op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") + op.execute( + """ + ALTER TABLE podcasts + ALTER COLUMN status TYPE podcast_status + USING ( + CASE status::text + WHEN 'awaiting_brief' THEN 'pending' + WHEN 'drafting' THEN 'generating' + WHEN 'awaiting_review' THEN 'generating' + WHEN 'rendering' THEN 'generating' + WHEN 'cancelled' THEN 'failed' + ELSE status::text + END + )::podcast_status; + """ + ) + op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';") + op.execute("DROP TYPE podcast_status_new;") diff --git a/surfsense_backend/alembic/versions/160_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py deleted file mode 100644 index fea45aca7..000000000 --- a/surfsense_backend/alembic/versions/160_add_model_connections.py +++ /dev/null @@ -1,299 +0,0 @@ -"""add model connections - -Revision ID: 160 -Revises: 159 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa 
-from sqlalchemy.dialects import postgresql - -from alembic import op - -revision: str = "160" -down_revision: str | None = "159" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -connection_scope = postgresql.ENUM( - "GLOBAL", - "SEARCH_SPACE", - "USER", - name="connectionscope", - create_type=False, -) -model_source = postgresql.ENUM( - "DISCOVERED", - "MANUAL", - name="modelsource", - create_type=False, -) - - -def _table_exists(table_name: str) -> bool: - return table_name in sa.inspect(op.get_bind()).get_table_names() - - -def _column_exists(table_name: str, column_name: str) -> bool: - if not _table_exists(table_name): - return False - return column_name in { - column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) - } - - -def _index_exists(table_name: str, index_name: str) -> bool: - if not _table_exists(table_name): - return False - return index_name in { - index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name) - } - - -def _create_index_if_missing( - index_name: str, - table_name: str, - columns: list[str], -) -> None: - if not _index_exists(table_name, index_name): - op.create_index(index_name, table_name, columns, unique=False) - - -def _add_searchspace_column_if_missing( - column_name: str, - *, - server_default: object | None = None, -) -> None: - if not _column_exists("searchspaces", column_name): - op.add_column( - "searchspaces", - sa.Column( - column_name, - sa.Integer(), - nullable=True, - server_default=server_default, - ), - ) - - -def _drop_column_if_exists(table_name: str, column_name: str) -> None: - if _column_exists(table_name, column_name): - op.drop_column(table_name, column_name) - - -def _drop_index_if_exists(table_name: str, index_name: str) -> None: - if _index_exists(table_name, index_name): - op.drop_index(index_name, table_name=table_name) - - -def upgrade() -> None: - bind = op.get_bind() - connection_scope.create(bind, 
checkfirst=True) - model_source.create(bind, checkfirst=True) - - if _table_exists("connections"): - if _column_exists("connections", "litellm_provider") and not _column_exists( - "connections", "provider" - ): - op.alter_column( - "connections", - "litellm_provider", - new_column_name="provider", - existing_type=sa.String(length=100), - existing_nullable=True, - ) - op.alter_column( - "connections", - "provider", - existing_type=sa.String(length=100), - nullable=False, - ) - elif _column_exists("connections", "native_provider") and not _column_exists( - "connections", "provider" - ): - op.alter_column( - "connections", - "native_provider", - new_column_name="provider", - existing_type=sa.String(length=100), - existing_nullable=True, - ) - op.alter_column( - "connections", - "provider", - existing_type=sa.String(length=100), - nullable=False, - ) - elif not _column_exists("connections", "provider"): - op.add_column( - "connections", - sa.Column("provider", sa.String(length=100), nullable=False), - ) - _drop_index_if_exists("connections", "ix_connections_protocol") - _drop_column_if_exists("connections", "protocol") - else: - op.create_table( - "connections", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("provider", sa.String(length=100), nullable=False), - sa.Column("base_url", sa.String(length=500), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column( - "extra", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("scope", connection_scope, nullable=False), - sa.Column( - "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False - ), - sa.Column("search_space_id", sa.Integer(), nullable=True), - sa.Column("user_id", sa.UUID(), nullable=True), - sa.CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND 
search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - if _index_exists( - "connections", "ix_connections_native_provider" - ) and not _index_exists("connections", "ix_connections_provider"): - op.execute( - "ALTER INDEX ix_connections_native_provider " - "RENAME TO ix_connections_provider" - ) - if _index_exists( - "connections", "ix_connections_litellm_provider" - ) and not _index_exists("connections", "ix_connections_provider"): - op.execute( - "ALTER INDEX ix_connections_litellm_provider " - "RENAME TO ix_connections_provider" - ) - _create_index_if_missing("ix_connections_provider", "connections", ["provider"]) - _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) - - if not _table_exists("models"): - op.create_table( - "models", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("connection_id", sa.Integer(), nullable=False), - sa.Column("model_id", sa.String(length=255), nullable=False), - sa.Column("display_name", sa.String(length=255), nullable=True), - sa.Column( - "source", - model_source, - server_default="DISCOVERED", - nullable=False, - ), - sa.Column("supports_chat", sa.Boolean(), nullable=True), - sa.Column("max_input_tokens", sa.Integer(), nullable=True), - sa.Column("supports_image_input", sa.Boolean(), nullable=True), - sa.Column("supports_tools", sa.Boolean(), nullable=True), - sa.Column("supports_image_generation", sa.Boolean(), nullable=True), - sa.Column( - "capabilities_override", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "enabled", sa.Boolean(), 
server_default=sa.text("true"), nullable=False - ), - sa.Column("billing_tier", sa.String(length=50), nullable=True), - sa.Column( - "catalog", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.ForeignKeyConstraint( - ["connection_id"], ["connections.id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - ) - else: - if not _column_exists("models", "supports_chat"): - op.add_column( - "models", sa.Column("supports_chat", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "max_input_tokens"): - op.add_column( - "models", sa.Column("max_input_tokens", sa.Integer(), nullable=True) - ) - if not _column_exists("models", "supports_image_input"): - op.add_column( - "models", sa.Column("supports_image_input", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "supports_tools"): - op.add_column( - "models", sa.Column("supports_tools", sa.Boolean(), nullable=True) - ) - if not _column_exists("models", "supports_image_generation"): - op.add_column( - "models", - sa.Column("supports_image_generation", sa.Boolean(), nullable=True), - ) - _drop_column_if_exists("models", "capabilities") - _drop_column_if_exists("models", "capabilities_declared") - _drop_column_if_exists("models", "capabilities_verified") - _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) - _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) - _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - - _add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0")) - _add_searchspace_column_if_missing( - "image_gen_model_id", server_default=sa.text("0") - ) - _add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0")) - for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"): - 
op.alter_column( - "searchspaces", - column_name, - existing_type=sa.Integer(), - existing_nullable=True, - server_default=sa.text("0"), - ) - op.execute( - """ - UPDATE searchspaces - SET - chat_model_id = COALESCE(chat_model_id, 0), - image_gen_model_id = COALESCE(image_gen_model_id, 0), - vision_model_id = COALESCE(vision_model_id, 0) - """ - ) - - op.execute("DROP TYPE IF EXISTS connectionprotocol") - - -def downgrade() -> None: - op.drop_column("searchspaces", "vision_model_id") - op.drop_column("searchspaces", "image_gen_model_id") - op.drop_column("searchspaces", "chat_model_id") - - op.drop_index(op.f("ix_models_billing_tier"), table_name="models") - op.drop_index("ix_models_model_id", table_name="models") - op.drop_index(op.f("ix_models_connection_id"), table_name="models") - op.drop_table("models") - - op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_provider"), table_name="connections") - op.drop_table("connections") - - bind = op.get_bind() - model_source.drop(bind, checkfirst=True) - connection_scope.drop(bind, checkfirst=True) diff --git a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py deleted file mode 100644 index 2108d763c..000000000 --- a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py +++ /dev/null @@ -1,270 +0,0 @@ -"""remove legacy model config tables - -Revision ID: 161 -Revises: 160 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql -from sqlalchemy.types import TypeEngine - -from alembic import op - -revision: str = "161" -down_revision: str | None = "160" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -litellm_provider = postgresql.ENUM( - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "BEDROCK", - "VERTEX_AI", - "GROQ", - "COHERE", - 
"MISTRAL", - "DEEPSEEK", - "XAI", - "OPENROUTER", - "TOGETHER_AI", - "FIREWORKS_AI", - "REPLICATE", - "PERPLEXITY", - "OLLAMA", - "ALIBABA_QWEN", - "MOONSHOT", - "ZHIPU", - "ANYSCALE", - "DEEPINFRA", - "CEREBRAS", - "SAMBANOVA", - "AI21", - "CLOUDFLARE", - "DATABRICKS", - "COMETAPI", - "HUGGINGFACE", - "GITHUB_MODELS", - "MINIMAX", - "CUSTOM", - name="litellmprovider", - create_type=False, -) -image_gen_provider = postgresql.ENUM( - "OPENAI", - "AZURE_OPENAI", - "GOOGLE", - "VERTEX_AI", - "BEDROCK", - "RECRAFT", - "OPENROUTER", - "XINFERENCE", - "NSCALE", - name="imagegenprovider", - create_type=False, -) -vision_provider = postgresql.ENUM( - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "VERTEX_AI", - "BEDROCK", - "XAI", - "OPENROUTER", - "OLLAMA", - "GROQ", - "TOGETHER_AI", - "FIREWORKS_AI", - "DEEPSEEK", - "MISTRAL", - "CUSTOM", - name="visionprovider", - create_type=False, -) - - -def _table_exists(table_name: str) -> bool: - return table_name in sa.inspect(op.get_bind()).get_table_names() - - -def _column_exists(table_name: str, column_name: str) -> bool: - if not _table_exists(table_name): - return False - return column_name in { - column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) - } - - -def _drop_column_if_exists(table_name: str, column_name: str) -> None: - if _column_exists(table_name, column_name): - op.drop_column(table_name, column_name) - - -def _rename_column_if_exists( - table_name: str, - old_column_name: str, - new_column_name: str, - *, - existing_type: TypeEngine, - existing_nullable: bool = True, -) -> None: - if _column_exists(table_name, old_column_name) and not _column_exists( - table_name, new_column_name - ): - op.alter_column( - table_name, - old_column_name, - new_column_name=new_column_name, - existing_type=existing_type, - existing_nullable=existing_nullable, - ) - - -def upgrade() -> None: - for table_name in ( - "new_llm_configs", - "vision_llm_configs", - "image_generation_configs", - ): - if 
_table_exists(table_name): - op.drop_table(table_name) - - _drop_column_if_exists("searchspaces", "agent_llm_id") - _drop_column_if_exists("searchspaces", "image_generation_config_id") - _drop_column_if_exists("searchspaces", "vision_llm_config_id") - - _rename_column_if_exists( - "image_generations", - "image_generation_config_id", - "image_gen_model_id", - existing_type=sa.Integer(), - ) - - op.execute("DROP TYPE IF EXISTS litellmprovider") - op.execute("DROP TYPE IF EXISTS imagegenprovider") - op.execute("DROP TYPE IF EXISTS visionprovider") - - -def downgrade() -> None: - bind = op.get_bind() - litellm_provider.create(bind, checkfirst=True) - image_gen_provider.create(bind, checkfirst=True) - vision_provider.create(bind, checkfirst=True) - - _rename_column_if_exists( - "image_generations", - "image_gen_model_id", - "image_generation_config_id", - existing_type=sa.Integer(), - ) - - if _table_exists("searchspaces"): - if not _column_exists("searchspaces", "agent_llm_id"): - op.add_column( - "searchspaces", - sa.Column("agent_llm_id", sa.Integer(), nullable=True), - ) - if not _column_exists("searchspaces", "image_generation_config_id"): - op.add_column( - "searchspaces", - sa.Column("image_generation_config_id", sa.Integer(), nullable=True), - ) - if not _column_exists("searchspaces", "vision_llm_config_id"): - op.add_column( - "searchspaces", - sa.Column("vision_llm_config_id", sa.Integer(), nullable=True), - ) - - if not _table_exists("image_generation_configs"): - op.create_table( - "image_generation_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", image_gen_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - 
sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("api_version", sa.String(length=50), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_image_generation_configs_name"), - "image_generation_configs", - ["name"], - unique=False, - ) - - if not _table_exists("vision_llm_configs"): - op.create_table( - "vision_llm_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", vision_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("api_version", sa.String(length=50), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_vision_llm_configs_name"), - "vision_llm_configs", - ["name"], - unique=False, - ) - - if not _table_exists("new_llm_configs"): - op.create_table( - 
"new_llm_configs", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("provider", litellm_provider, nullable=False), - sa.Column("custom_provider", sa.String(length=100), nullable=True), - sa.Column("model_name", sa.String(length=100), nullable=False), - sa.Column("api_key", sa.String(), nullable=False), - sa.Column("api_base", sa.String(length=500), nullable=True), - sa.Column("litellm_params", sa.JSON(), nullable=True), - sa.Column("system_instructions", sa.Text(), nullable=False), - sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False), - sa.Column("citations_enabled", sa.Boolean(), nullable=False), - sa.Column("search_space_id", sa.Integer(), nullable=False), - sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_new_llm_configs_name"), - "new_llm_configs", - ["name"], - unique=False, - ) diff --git a/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py b/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py deleted file mode 100644 index 87e1c5813..000000000 --- a/surfsense_backend/alembic/versions/162_add_etl_cache_parses.py +++ /dev/null @@ -1,53 +0,0 @@ -"""add etl_cache_parses table for content-addressed parse reuse - -Revision ID: 162 -Revises: 161 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "162" -down_revision: str | None = "161" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.execute( - """ - CREATE TABLE IF NOT EXISTS etl_cache_parses ( 
- id SERIAL PRIMARY KEY, - source_sha256 VARCHAR(64) NOT NULL, - etl_service VARCHAR(32) NOT NULL, - mode VARCHAR(16) NOT NULL, - parser_version INTEGER NOT NULL, - storage_backend VARCHAR(32) NOT NULL, - storage_key TEXT NOT NULL, - size_bytes BIGINT NOT NULL, - content_type VARCHAR(32) NOT NULL, - actual_pages INTEGER NOT NULL DEFAULT 0, - times_reused BIGINT NOT NULL DEFAULT 0, - last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - CONSTRAINT uq_etl_cache_parses_key - UNIQUE (source_sha256, etl_service, mode, parser_version) - ); - """ - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_last_used_at " - "ON etl_cache_parses(last_used_at);" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_created_at " - "ON etl_cache_parses(created_at);" - ) - - -def downgrade() -> None: - op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_created_at;") - op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_last_used_at;") - op.execute("DROP TABLE IF EXISTS etl_cache_parses;") diff --git a/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py b/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py deleted file mode 100644 index f15616332..000000000 --- a/surfsense_backend/alembic/versions/163_add_embedding_cache_sets.py +++ /dev/null @@ -1,53 +0,0 @@ -"""add embedding_cache_sets table for content-addressed embedding reuse - -Revision ID: 163 -Revises: 162 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "163" -down_revision: str | None = "162" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.execute( - """ - CREATE TABLE IF NOT EXISTS embedding_cache_sets ( - id SERIAL PRIMARY KEY, - markdown_sha256 VARCHAR(64) NOT NULL, - embedding_model VARCHAR(255) NOT NULL, - embedding_dim INTEGER NOT NULL, - chunker_kind VARCHAR(8) NOT NULL, - 
chunker_version INTEGER NOT NULL, - storage_backend VARCHAR(32) NOT NULL, - storage_key TEXT NOT NULL, - size_bytes BIGINT NOT NULL, - chunk_count INTEGER NOT NULL DEFAULT 0, - times_reused BIGINT NOT NULL DEFAULT 0, - last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - CONSTRAINT uq_embedding_cache_sets_key - UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version) - ); - """ - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_last_used_at " - "ON embedding_cache_sets(last_used_at);" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_created_at " - "ON embedding_cache_sets(created_at);" - ) - - -def downgrade() -> None: - op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_created_at;") - op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_last_used_at;") - op.execute("DROP TABLE IF EXISTS embedding_cache_sets;") diff --git a/surfsense_backend/alembic/versions/164_remove_inactive_users.py b/surfsense_backend/alembic/versions/164_remove_inactive_users.py deleted file mode 100644 index 3ce23e204..000000000 --- a/surfsense_backend/alembic/versions/164_remove_inactive_users.py +++ /dev/null @@ -1,219 +0,0 @@ -"""remove users that never logged back in (last_login IS NULL) - -Migration 103 added ``user.last_login``. Any user whose ``last_login`` is still -NULL has never authenticated since that column existed, i.e. they never logged -back in. This migration purges those users together with everything that hangs -off them: the search spaces they own, and (via ON DELETE CASCADE) -``searchspaces -> documents -> chunks`` plus all other user/space-scoped rows. - -This runs BEFORE the chunks.position backfill (revision 165) on purpose: it -removes a large amount of dead chunk data first, so the expensive backfill has -far fewer rows to rewrite. 
- -Work is done in committed batches (not one giant cascading DELETE) so that on a -large table it streams progress to the alembic console, keeps each transaction -small, bounds WAL/bloat growth, and is resumable if interrupted. - -Revision ID: 164 -Revises: 163 -""" - -import logging -import time -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "164" -down_revision: str | None = "163" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - -# Documents removed per committed batch. Each document delete cascades to its -# chunks (via ix_chunks_document_id), so keep this modest to bound batch size. -DOC_BATCH = 1_000 -# Users removed per committed batch. Each cascades to owned search spaces and -# the remaining space-/user-scoped rows. -USER_BATCH = 500 -# Minimum seconds between progress log lines (keeps the console readable). -LOG_EVERY_SECONDS = 5.0 - -USER_SCRATCH = "_inactive_user_ids" -DOC_SCRATCH = "_inactive_doc_ids" - -logger = logging.getLogger("alembic.runtime.migration") - - -def _fmt_duration(seconds: float) -> str: - seconds = int(seconds) - h, rem = divmod(seconds, 3600) - m, s = divmod(rem, 60) - if h: - return f"{h}h{m:02d}m{s:02d}s" - if m: - return f"{m}m{s:02d}s" - return f"{s}s" - - -def upgrade() -> None: - bind = op.get_bind() - - # Run the heavy work outside the migration's single transaction so each - # batch can commit on its own. - with op.get_context().autocommit_block(): - # Materialize the target user ids once. Rebuilt from scratch on every - # run, so a re-run after an interruption simply picks up whoever still - # has NULL last_login -> the migration is idempotent and resumable. 
- op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - op.execute( - f"CREATE UNLOGGED TABLE {USER_SCRATCH} AS " - 'SELECT id FROM "user" WHERE last_login IS NULL;' - ) - op.execute(f"ALTER TABLE {USER_SCRATCH} ADD PRIMARY KEY (id);") - - total_users = ( - bind.execute(sa.text(f"SELECT count(*) FROM {USER_SCRATCH}")).scalar() or 0 - ) - if total_users == 0: - logger.info("no users with NULL last_login; nothing to remove") - op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - return - - logger.info( - "found %s users with NULL last_login (never logged back in); " - "removing them and all data in search spaces they own", - f"{total_users:,}", - ) - - # Documents living in search spaces owned by those users. Deleting these - # explicitly (in batches) is what bounds the otherwise-unbounded - # chunks cascade. - op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};") - op.execute( - f""" - CREATE UNLOGGED TABLE {DOC_SCRATCH} AS - SELECT d.id - FROM documents d - JOIN searchspaces s ON s.id = d.search_space_id - WHERE s.user_id IN (SELECT id FROM {USER_SCRATCH}); - """ - ) - op.execute(f"ALTER TABLE {DOC_SCRATCH} ADD PRIMARY KEY (id);") - total_docs = ( - bind.execute(sa.text(f"SELECT count(*) FROM {DOC_SCRATCH}")).scalar() or 0 - ) - - # Phase 1: delete documents (cascades chunks, document_versions, - # document_files) in committed batches. - logger.info( - "phase 1/2: deleting %s documents (cascades their chunks) " - "in batches of %s...", - f"{total_docs:,}", - f"{DOC_BATCH:,}", - ) - _batched_delete( - bind, - scratch=DOC_SCRATCH, - target_table="documents", - target_col="id", - batch_size=DOC_BATCH, - total=total_docs, - label="documents", - ) - op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};") - - # Phase 2: delete the users themselves. This cascades the now-empty - # search spaces plus all remaining user-/space-scoped rows. 
- logger.info( - "phase 2/2: deleting %s users (cascades search spaces and " - "remaining data) in batches of %s...", - f"{total_users:,}", - f"{USER_BATCH:,}", - ) - _batched_delete( - bind, - scratch=USER_SCRATCH, - target_table='"user"', - target_col="id", - batch_size=USER_BATCH, - total=total_users, - label="users", - ) - op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};") - - logger.info("migration 164 finished") - - -def _batched_delete( - bind: sa.engine.Connection, - *, - scratch: str, - target_table: str, - target_col: str, - batch_size: int, - total: int, - label: str, -) -> None: - """Pop ids from ``scratch`` and delete the matching rows, one committed - batch at a time, logging progress. Atomic per batch: the row delete and the - scratch pop happen in a single statement, so an interrupted run leaves the - scratch table in sync with what has actually been deleted.""" - started = time.monotonic() - last_log = 0.0 - done = 0 - - stmt = sa.text( - f""" - WITH batch AS ( - SELECT id FROM {scratch} LIMIT :n - ), deleted AS ( - DELETE FROM {target_table} - WHERE {target_col} IN (SELECT id FROM batch) - ), popped AS ( - DELETE FROM {scratch} - WHERE id IN (SELECT id FROM batch) - RETURNING id - ) - SELECT count(*) FROM popped - """ - ) - - while True: - popped = bind.execute(stmt, {"n": batch_size}).scalar() or 0 - if popped == 0: - break - done += popped - - now = time.monotonic() - if now - last_log >= LOG_EVERY_SECONDS or done >= total: - elapsed = now - started - pct = (100.0 * done / total) if total else 100.0 - eta = (elapsed / pct * (100.0 - pct)) if pct > 0 else 0.0 - logger.info( - "%s deleted: %.1f%% (%s/%s) elapsed %s eta %s", - label, - pct, - f"{done:,}", - f"{total:,}", - _fmt_duration(elapsed), - _fmt_duration(eta), - ) - last_log = now - - logger.info( - "deleted %s %s in %s", - f"{done:,}", - label, - _fmt_duration(time.monotonic() - started), - ) - - -def downgrade() -> None: - # Irreversible: deleted users and their cascaded data cannot be 
restored. - # No-op so the downgrade chain can still pass through this revision. - logger.warning( - "migration 164 (remove_inactive_users) is irreversible; " - "downgrade is a no-op (deleted users/data are not restored)" - ) diff --git a/surfsense_backend/alembic/versions/165_add_chunk_position.py b/surfsense_backend/alembic/versions/165_add_chunk_position.py deleted file mode 100644 index b214f3d89..000000000 --- a/surfsense_backend/alembic/versions/165_add_chunk_position.py +++ /dev/null @@ -1,49 +0,0 @@ -"""add chunks.position for explicit document order - -Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no -longer reflect document order. The ``position`` column makes that order -explicit and is written by the indexing pipeline for every new or re-indexed -document. - -This migration intentionally does NOT backfill historical rows. On a large, -heavily-indexed table (notably a multi-hundred-GB HNSW embedding index) a bulk -UPDATE of every chunk becomes a non-HOT update that rewrites every secondary -index per row -- turning a one-off migration into a multi-day operation. -Instead, existing rows keep ``position = 0`` and therefore order by the -``Chunk.id`` tiebreaker (identical to the pre-feature behavior); new and -re-indexed documents get correct positions from application code, and any -document needing exact order can simply be re-indexed on demand. - -Revision ID: 165 -Revises: 164 -""" - -from collections.abc import Sequence - -from alembic import op - -revision: str = "165" -down_revision: str | None = "164" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - -# Leftover UNLOGGED scratch table from earlier backfill attempts; dropped here -# so re-running this migration converges the schema regardless of past state. 
-SCRATCH_TABLE = "_chunk_position_backfill" - - -def upgrade() -> None: - # Adding a NOT NULL column with a constant default is metadata-only on - # PostgreSQL 11+, so this is fast even on very large tables. IF NOT EXISTS - # makes it a no-op where the column already exists. - op.execute( - "ALTER TABLE chunks ADD COLUMN IF NOT EXISTS position INTEGER NOT NULL DEFAULT 0;" - ) - - # Clean up the scratch table left behind by the abandoned backfill approach. - op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};") - - -def downgrade() -> None: - op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};") - op.execute("ALTER TABLE chunks DROP COLUMN IF EXISTS position;") diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py index a6c83a7d4..ef86eaddd 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py @@ -241,15 +241,8 @@ async def _create_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk( - document_id=doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunks, chunk_embeddings, strict=True) - ) + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) ] ) return doc @@ -296,15 +289,8 @@ async def _update_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk( - document_id=document.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunks, chunk_embeddings, strict=True) - ) + Chunk(document_id=document.id, content=text, embedding=embedding) + for text, embedding in 
zip(chunks, chunk_embeddings, strict=True) ] ) return document @@ -489,9 +475,7 @@ async def _load_chunks_for_snapshot( session: AsyncSession, *, doc_id: int ) -> list[dict[str, str]]: rows = await session.execute( - select(Chunk.content) - .where(Chunk.document_id == doc_id) - .order_by(Chunk.position, Chunk.id) + select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) ) return [{"content": row.content} for row in rows.all() if row.content is not None] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py index 2d3599de0..6ac22e575 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,7 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, - image_gen_model_id_override: int | None = None, + image_generation_config_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -121,7 +121,7 @@ async def build_agent_with_cache( # Bound into the generate_image subagent tool at construction time, so it # must key the compiled-agent cache to avoid leaking one automation's # image model into another with the same config_id/search_space. 
- image_gen_model_id_override, + image_generation_config_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index 10a734192..adb1bc1ed 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, - image_gen_model_id: int | None = None, + image_generation_config_id: int | None = None, ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. - ``image_gen_model_id`` overrides the search space's image model for + ``image_generation_config_id`` overrides the search space's image model for this invocation (used by automations to run on their captured model). When ``None``, the ``generate_image`` tool resolves the live search-space pref. """ @@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent( "llm": llm, # Per-invocation image model override (automations run on their captured # model). Reaches the generate_image subagent tool via subagent_dependencies. 
- "image_gen_model_id_override": image_gen_model_id, + "image_generation_config_id_override": image_generation_config_id, } _t0 = time.perf_counter() @@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, - image_gen_model_id_override=image_gen_model_id, + image_generation_config_id_override=image_generation_config_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py index e13196537..7b8aaf2b0 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py @@ -508,7 +508,7 @@ class KBPostgresBackend(BackendProtocol): chunk_rows = await session.execute( select(Chunk.id, Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunks = [ {"chunk_id": row.id, "content": row.content} for row in chunk_rows.all() @@ -725,7 +725,7 @@ class KBPostgresBackend(BackendProtocol): .join(Document, Document.id == Chunk.document_id) .where(Document.search_space_id == self.search_space_id) .where(Chunk.content.ilike(f"%{pattern}%")) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunk_rows = await session.execute(sub) per_doc: dict[int, int] = {} diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py index 9ef601791..681e80b0e 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py 
+++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py @@ -394,10 +394,7 @@ async def browse_recent_documents( Chunk.document_id, Chunk.content, func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -407,7 +404,7 @@ async def browse_recent_documents( chunk_query = ( select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content) .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC) - .order_by(numbered.c.document_id, numbered.c.rn) + .order_by(numbered.c.document_id, numbered.c.chunk_id) ) chunk_result = await session.execute(chunk_query) fetched_chunks = chunk_result.all() @@ -534,7 +531,7 @@ async def fetch_mentioned_documents( chunk_result = await session.execute( select(Chunk.id, Chunk.content, Chunk.document_id) .where(Chunk.document_id.in_(list(docs.keys()))) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs} for row in chunk_result.all(): diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 736c508ff..7bb4a7c24 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -10,53 +10,70 @@ from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from 
app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, - Model, + ImageGenerationConfig, SearchSpace, shielded_async_session, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) - -def _get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) +# Provider mapping (same as routes) +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} -def _get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" + + +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image gen config by negative ID.""" + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None def create_generate_image_tool( search_space_id: int, db_session: 
AsyncSession, - image_gen_model_id_override: int | None = None, + image_generation_config_id_override: int | None = None, ): """Create ``generate_image`` with bound search space; DB work uses a per-call session. - ``image_gen_model_id_override``: when set (automations running on a - captured model), use this model id instead of reading the search space's - live ``image_gen_model_id``. + ``image_generation_config_id_override``: when set (automations running on a + captured model), use this config id instead of reading the search space's + live ``image_generation_config_id``. """ del db_session # tool uses a fresh per-call session instead @@ -101,23 +118,26 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return _failed( - {"error": "Search space not found"}, - error="Search space not found", - ) - - if image_gen_model_id_override is not None: + if image_generation_config_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. 
- config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID - else: config_id = ( - search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID + image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + ) + else: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) + + config_id = ( + search_space.image_generation_config_id + or IMAGE_GEN_AUTO_MODE_ID ) # size/quality/style are intentionally omitted: valid values @@ -127,86 +147,73 @@ def create_generate_image_tool( gen_kwargs["n"] = n if is_image_gen_auto_mode(config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=search_space.user_id, - capability="image_gen", - ) - if not candidates: + if not ImageGenRouterService.is_initialized(): err = ( - "No image generation models available. " + "No image generation models configured. " "Please add an image model in Settings > Image Models." 
) return _failed({"error": err}, error=err) - config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs ) - - provider_base_url: str | None = None - - if config_id < 0: - global_model = _get_global_model(config_id) - if not global_model or not has_capability( - global_model, "image_gen" - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - global_connection = _get_global_connection( - global_model["connection_id"] - ) - if not global_connection: - err = f"Image generation connection for model {config_id} not found" + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string, resolved_kwargs = to_litellm( - global_connection, - global_model["model_id"], + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) - gen_kwargs.update(resolved_kwargs) - provider_base_url = resolved_kwargs.get("api_base") + model_string = f"{provider_prefix}/{cfg['model_name']}" + gen_kwargs["api_key"] = cfg.get("api_key") + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). 
+ api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = Model + Connection + # Positive ID = user-created ImageGenerationConfig cfg_result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .filter(Model.id == config_id, Model.enabled.is_(True)) + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) ) - db_model = cfg_result.scalars().first() - if ( - not db_model - or not db_model.connection - or not db_model.connection.enabled - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - conn = db_model.connection - if ( - conn.search_space_id is not None - and conn.search_space_id != search_space_id - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - if ( - conn.user_id is not None - and conn.user_id != search_space.user_id - ): - err = f"Image generation model {config_id} not found" - return _failed({"error": err}, error=err) - if not has_capability(db_model, "image_gen"): - err = f"Model {config_id} is not image-generation capable" + db_cfg = cfg_result.scalars().first() + if not db_cfg: + err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string, resolved_kwargs = to_litellm( - db_model.connection, - db_model.model_id, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) - gen_kwargs.update(resolved_kwargs) - provider_base_url = resolved_kwargs.get("api_base") + model_string = 
f"{provider_prefix}/{db_cfg.model_name}" + gen_kwargs["api_key"] = db_cfg.api_key + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs @@ -223,7 +230,7 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - image_gen_model_id=config_id, + image_generation_config_id=config_id, response_data=response_dict, search_space_id=search_space_id, access_token=access_token, @@ -245,19 +252,8 @@ def create_generate_image_tool( # b64_json (e.g. gpt-image-1) is served via our backend endpoint so # megabytes of base64 don't bloat the LLM context. - # Some OpenAI-compatible backends (e.g. Xinference) return a relative - # URL like /files/image.png. Browsers can't resolve these, so we - # prepend the provider's base origin when the URL starts with "/". 
if first_image.get("url"): - raw_url: str = first_image["url"] - if raw_url.startswith("/") and provider_base_url: - from urllib.parse import urlparse - - parsed = urlparse(provider_base_url) - origin = f"{parsed.scheme}://{parsed.netloc}" - image_url = f"{origin}{raw_url}" - else: - image_url = raw_url + image_url = first_image["url"] elif first_image.get("b64_json"): backend_url = config.BACKEND_URL or "http://localhost:8000" image_url = ( diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index 8de95f2df..b968c1701 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,6 +51,8 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], - image_gen_model_id_override=d.get("image_gen_model_id_override"), + image_generation_config_id_override=d.get( + "image_generation_config_id_override" + ), ), ] diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py index d89124990..e99e0291a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py @@ -122,7 +122,7 @@ async def _browse_recent_documents( chunk_query = ( select(Chunk) .where(Chunk.document_id.in_(doc_ids)) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) chunk_result = await session.execute(chunk_query) raw_chunks = chunk_result.scalars().all() 
diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index e7f2a0f0d..aad432edb 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -2,9 +2,9 @@ LLM configuration utilities for SurfSense agents. This module provides functions for loading LLM configurations from: -1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model +1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing 2. YAML files (global configs with negative IDs) -3. Database model-connections table (user-created configs with positive IDs) +3. Database NewLLMConfig table (user-created configs with positive IDs) It also provides utilities for creating ChatLiteLLM instances and managing prompt configurations. @@ -24,6 +24,8 @@ from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_litellm import ChatLiteLLM from litellm import get_model_info +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, @@ -31,7 +33,10 @@ from app.agents.chat.runtime.prompt_caching import ( from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, + LLMRouterService, _sanitize_content, + get_auto_mode_llm, + is_auto_mode, ) @@ -46,19 +51,16 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: reject the blank text. The OpenAI spec says ``content`` should be ``null`` when an assistant message only carries tool calls. 
""" - sanitized: list[BaseMessage] = [] for msg in messages: - next_msg = msg.model_copy(deep=True) - if isinstance(next_msg.content, list): - next_msg.content = _sanitize_content(next_msg.content) + if isinstance(msg.content, list): + msg.content = _sanitize_content(msg.content) if ( - isinstance(next_msg, AIMessage) - and (not next_msg.content or next_msg.content == "") - and getattr(next_msg, "tool_calls", None) + isinstance(msg, AIMessage) + and (not msg.content or msg.content == "") + and getattr(msg, "tool_calls", None) ): - next_msg.content = None # type: ignore[assignment] - sanitized.append(next_msg) - return sanitized + msg.content = None # type: ignore[assignment] + return messages class SanitizedChatLiteLLM(ChatLiteLLM): @@ -89,21 +91,13 @@ class SanitizedChatLiteLLM(ChatLiteLLM): ): yield chunk - async def _agenerate( - self, - messages: list[BaseMessage], - stop: list[str] | None = None, - run_manager: AsyncCallbackManagerForLLMRun | None = None, - stream: bool | None = None, - **kwargs: Any, - ) -> ChatResult: - return await super()._agenerate( - _sanitize_messages(messages), - stop=stop, - run_manager=run_manager, - stream=stream, - **kwargs, - ) + +# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives +# in provider_capabilities so the YAML loader can resolve prefixes during +# app.config init without importing the agent/tools tree. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -127,9 +121,8 @@ class AgentConfig: """ Complete configuration for the SurfSense agent. - This combines resolved model settings with prompt configuration. - Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to - a concrete global or BYOK model before constructing ChatLiteLLM. + This combines LLM settings with prompt configuration from NewLLMConfig. 
+ Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. """ # LLM Model Settings @@ -177,7 +170,7 @@ class AgentConfig: use_default_system_instructions=True, citations_enabled=True, config_id=AUTO_MODE_ID, - config_name="Auto", + config_name="Auto (Fastest)", is_auto_mode=True, billing_tier="free", is_premium=False, @@ -188,21 +181,64 @@ class AgentConfig: supports_image_input=True, ) + @classmethod + def from_new_llm_config(cls, config) -> "AgentConfig": + """Build an AgentConfig from a NewLLMConfig database model.""" + # Lazy import: keeps provider_capabilities (and litellm) out of init order. + from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, + model_name=config.model_name, + api_key=config.api_key, + api_base=config.api_base, + custom_provider=config.custom_provider, + litellm_params=config.litellm_params, + system_instructions=config.system_instructions, + use_default_system_instructions=config.use_default_system_instructions, + citations_enabled=config.citations_enabled, + config_id=config.id, + config_name=config.name, + is_auto_mode=False, + billing_tier="free", + is_premium=False, + anonymous_enabled=False, + quota_reserve_tokens=None, + # BYOK rows have no curated flag; ask LiteLLM (default-allow on + # unknown). The streaming safety net still blocks explicit text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), + ) + @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": """Build an AgentConfig from a YAML configuration dictionary. 
- Supports prompt fields such as system_instructions, - use_default_system_instructions, and citations_enabled. + Supports the same prompt fields as NewLLMConfig (system_instructions, + use_default_system_instructions, citations_enabled). """ # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("provider") or yaml_config.get( - "litellm_provider", "" - ) + provider = yaml_config.get("provider", "").upper() model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -288,15 +324,93 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: return load_llm_config_from_yaml(llm_config_id) +async def load_new_llm_config_from_db( + session: AsyncSession, + config_id: int, +) -> "AgentConfig | None": + """Load a NewLLMConfig from the database by ID.""" + from app.db import NewLLMConfig + + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + print(f"Error: NewLLMConfig with id {config_id} not found") + return None + + return AgentConfig.from_new_llm_config(config) + except Exception as e: + print(f"Error loading NewLLMConfig from database: {e}") + return None + + +async def load_agent_llm_config_for_search_space( + session: AsyncSession, + search_space_id: int, +) -> "AgentConfig | None": + """Load the agent LLM config for a search space via its agent_llm_id. + + Positive id -> DB; negative -> YAML; None -> first global config (-1). 
+ """ + from app.db import SearchSpace + + try: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + print(f"Error: SearchSpace with id {search_space_id} not found") + return None + + config_id = ( + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + ) + return await load_agent_config(session, config_id, search_space_id) + except Exception as e: + print(f"Error loading agent LLM config for search space {search_space_id}: {e}") + return None + + +async def load_agent_config( + session: AsyncSession, + config_id: int, + search_space_id: int | None = None, +) -> "AgentConfig | None": + """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" + if is_auto_mode(config_id): + if not LLMRouterService.is_initialized(): + print("Error: Auto mode requested but LLM Router not initialized") + return None + return AgentConfig.from_auto_mode() + + if config_id < 0: + # In-memory covers static YAML + dynamic OpenRouter configs. 
+ from app.config import config as app_config + + for cfg in app_config.GLOBAL_LLM_CONFIGS: + if cfg.get("id") == config_id: + return AgentConfig.from_yaml_config(cfg) + yaml_config = load_llm_config_from_yaml(config_id) + if yaml_config: + return AgentConfig.from_yaml_config(yaml_config) + return None + else: + return await load_new_llm_config_from_db(session, config_id) + + def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - provider = llm_config.get("provider") or llm_config.get( - "litellm_provider", "openai" - ) - model_string = f"{provider}/{llm_config['model_name']}" + provider = llm_config.get("provider", "").upper() + provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, @@ -319,17 +433,29 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Create a ChatLiteLLM from an already resolved concrete model config.""" + """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" if agent_config.is_auto_mode: - print( - "Error: Auto mode must be resolved to a concrete model before LLM creation" - ) - return None + if not LLMRouterService.is_initialized(): + print("Error: Auto mode requested but LLM Router not initialized") + return None + try: + router_llm = get_auto_mode_llm() + if router_llm is not None: + # Universal injection points only: auto-mode fans out across + # providers, so provider-specific kwargs have no known target. 
+ apply_litellm_prompt_caching(router_llm, agent_config=agent_config) + return router_llm + except Exception as e: + print(f"Error creating ChatLiteLLMRouter: {e}") + return None if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: - model_string = f"{agent_config.provider}/{agent_config.model_name}" + provider_prefix = PROVIDER_MAP.get( + agent_config.provider, agent_config.provider.lower() + ) + model_string = f"{provider_prefix}/{agent_config.model_name}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 6dfe6a776..d3f5dce2a 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -33,6 +33,7 @@ from app.config import ( initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, + initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError @@ -621,6 +622,7 @@ async def lifespan(app: FastAPI): initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() + initialize_vision_llm_router() # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. 
``shield`` so Uvicorn cancelling startup diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py index c9584ae2a..4ef8c52bf 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -39,31 +39,31 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, - chat_model_id: int | None = None, - image_gen_model_id: int | None = None, - vision_model_id: int | None = None, + agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Resolves the chat model from the automation's *captured* model snapshot - (``chat_model_id``) so runs are insulated from later chat/search-space model + Resolves the agent LLM from the automation's *captured* model snapshot + (``agent_llm_id``) so runs are insulated from later chat/search-space model changes. The model policy is enforced here as a runtime backstop: a captured model that is no longer billable (e.g. a premium global config was removed) fails the run clearly instead of silently consuming a free model. - When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback), - fall back to the live search space's ``chat_model_id`` and validate that. + When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``agent_llm_id`` and validate that. 
""" - if chat_model_id is not None: + if agent_llm_id is not None: try: assert_models_billable( - chat_model_id=chat_model_id, - image_gen_model_id=image_gen_model_id, - vision_model_id=vision_model_id, + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, ) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_chat_model_id = chat_model_id or 0 + resolved_agent_llm_id = agent_llm_id or 0 else: search_space = await session.get(SearchSpace, search_space_id) if search_space is None: @@ -72,15 +72,15 @@ async def build_dependencies( assert_automation_models_billable(search_space) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_chat_model_id = search_space.chat_model_id or 0 + resolved_agent_llm_id = search_space.agent_llm_id or 0 llm, agent_config, err = await load_llm_bundle( session, - config_id=resolved_chat_model_id, + config_id=resolved_agent_llm_id, search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load chat model config") + raise DependencyError(err or "failed to load agent LLM config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index c3a35930d..aa96e4f6e 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -150,9 +150,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, - chat_model_id=ctx.chat_model_id, - image_gen_model_id=ctx.image_gen_model_id, - vision_model_id=ctx.vision_model_id, + agent_llm_id=ctx.agent_llm_id, + 
image_generation_config_id=ctx.image_generation_config_id, + vision_llm_config_id=ctx.vision_llm_config_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -167,7 +167,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, - image_gen_model_id=ctx.image_gen_model_id, + image_generation_config_id=ctx.image_generation_config_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 3ee427512..453721a43 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -23,9 +23,9 @@ class ActionContext: # Captured model snapshot from the automation definition (``definition.models``), # resolved per run instead of the live search space. ``None`` falls back to the # search space's current prefs (defensive; should not happen post-capture). 
- chat_model_id: int | None = None - image_gen_model_id: int | None = None - vision_model_id: int | None = None + agent_llm_id: int | None = None + image_generation_config_id: int | None = None + vision_llm_config_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index bcdab3940..da249d8e5 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -132,7 +132,9 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, - chat_model_id=models.chat_model_id if models else None, - image_gen_model_id=models.image_gen_model_id if models else None, - vision_model_id=models.vision_model_id if models else None, + agent_llm_id=models.agent_llm_id if models else None, + image_generation_config_id=( + models.image_generation_config_id if models else None + ), + vision_llm_config_id=models.vision_llm_config_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index 787534d4a..7ca55b1ce 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec class AutomationModels(BaseModel): """Captured model profile for an automation. - Snapshotted from the search space's model roles at create time so runs are - insulated from later chat/search-space model changes. Model-id conventions + Snapshotted from the search space's preferences at create time so runs are + insulated from later chat/search-space model changes. Config-id conventions match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). 
""" model_config = ConfigDict(extra="forbid") - chat_model_id: int = 0 - image_gen_model_id: int = 0 - vision_model_id: int = 0 + agent_llm_id: int = 0 + image_generation_config_id: int = 0 + vision_llm_config_id: int = 0 class AutomationDefinition(BaseModel): diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 1d371c35d..4227161e2 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -57,9 +57,9 @@ class AutomationService: else: search_space = await self._assert_models_billable(payload.search_space_id) payload.definition.models = AutomationModels( - chat_model_id=search_space.chat_model_id or 0, - image_gen_model_id=search_space.image_gen_model_id or 0, - vision_model_id=search_space.vision_model_id or 0, + agent_llm_id=search_space.agent_llm_id or 0, + image_generation_config_id=search_space.image_generation_config_id or 0, + vision_llm_config_id=search_space.vision_llm_config_id or 0, ) automation = Automation( @@ -225,9 +225,9 @@ class AutomationService: """ try: assert_models_billable( - chat_model_id=models.chat_model_id, - image_gen_model_id=models.image_gen_model_id, - vision_model_id=models.vision_model_id, + agent_llm_id=models.agent_llm_id, + image_generation_config_id=models.image_generation_config_id, + vision_llm_config_id=models.vision_llm_config_id, ) except AutomationModelPolicyError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index e18264246..7e3e46b61 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -2,11 +2,11 @@ Automations run unattended, so every run must be **billable**: it may only use either a premium global model (``billing_tier 
== "premium"``) or a user-provided -BYOK model (a positive model id pointing at a per-user/per-space DB row). Free +BYOK model (a positive config id pointing at a per-user/per-space DB row). Free global models and Auto mode are blocked, because Auto can dispatch to a free deployment and free models aren't metered in premium credits. -Model id conventions (shared across chat / image / vision): +Config id conventions (shared across chat / image / vision): - ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / ``VISION_AUTO_MODE_ID``). Blocked. - ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. @@ -24,45 +24,70 @@ from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from app.db import SearchSpace -ModelKind = Literal["chat", "image", "vision"] +ModelKind = Literal["llm", "image", "vision"] _KIND_LABEL: dict[ModelKind, str] = { - "chat": "chat model", + "llm": "agent LLM", "image": "image generation model", "vision": "vision model", } -def _is_premium_global(model_id: int) -> bool: - """Return True if a negative (global) model id is a premium tier model.""" +def _is_premium_global(kind: ModelKind, config_id: int) -> bool: + """Return True if a negative (global) config id is a premium tier model.""" from app.config import config as app_config - model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None) - if not model: + cfg: dict | None = None + if kind == "llm": + from app.agents.chat.runtime.llm_config import ( + load_global_llm_config_by_id, + ) + + cfg = load_global_llm_config_by_id(config_id) + elif kind == "image": + cfg = next( + ( + c + for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS + if c.get("id") == config_id + ), + None, + ) + else: # vision + cfg = next( + ( + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if c.get("id") == config_id + ), + None, + ) + + if not cfg: return False - return str(model.get("billing_tier", "free")).lower() == "premium" + return str(cfg.get("billing_tier", 
"free")).lower() == "premium" -def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: - """Classify a resolved model id as allowed or blocked. +def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: + """Classify a resolved config id as allowed or blocked. Returns ``(allowed, reason)``; ``reason`` is empty when allowed. """ label = _KIND_LABEL[kind] - if model_id is None or model_id == 0: + if config_id is None or config_id == 0: return ( False, f"The {label} is set to Auto mode. Automations require an explicit " "premium model or your own (BYOK) model so every run is billable.", ) - if model_id > 0: - # Positive id -> user/search-space BYOK model. Always allowed. + if config_id > 0: + # Positive id → user-owned BYOK config. Always allowed. return True, "" - # Negative id -> global model. Allowed only if premium. - if _is_premium_global(model_id): + # Negative id → global config. Allowed only if premium. + if _is_premium_global(kind, config_id): return True, "" return ( @@ -74,27 +99,27 @@ def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: def get_model_eligibility( *, - chat_model_id: int | None, - image_gen_model_id: int | None, - vision_model_id: int | None, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, ) -> dict: - """Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids. + """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is - ``{"kind", "model_id", "reason"}``. + ``{"kind", "config_id", "reason"}``. 
""" checks: list[tuple[ModelKind, int | None]] = [ - ("chat", chat_model_id), - ("image", image_gen_model_id), - ("vision", vision_model_id), + ("llm", agent_llm_id), + ("image", image_generation_config_id), + ("vision", vision_llm_config_id), ] violations: list[dict] = [] - for kind, model_id in checks: - allowed, reason = _classify(kind, model_id) + for kind, config_id in checks: + allowed, reason = _classify(kind, config_id) if not allowed: - violations.append({"kind": kind, "model_id": model_id, "reason": reason}) + violations.append({"kind": kind, "config_id": config_id, "reason": reason}) return {"allowed": not violations, "violations": violations} @@ -106,9 +131,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict: wrapper over :func:`get_model_eligibility`. """ return get_model_eligibility( - chat_model_id=search_space.chat_model_id, - image_gen_model_id=search_space.image_gen_model_id, - vision_model_id=search_space.vision_model_id, + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, ) @@ -125,9 +150,9 @@ class AutomationModelPolicyError(Exception): def assert_models_billable( *, - chat_model_id: int | None, - image_gen_model_id: int | None, - vision_model_id: int | None, + agent_llm_id: int | None, + image_generation_config_id: int | None, + vision_llm_config_id: int | None, ) -> None: """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. @@ -135,9 +160,9 @@ def assert_models_billable( captured model snapshot. 
""" result = get_model_eligibility( - chat_model_id=chat_model_id, - image_gen_model_id=image_gen_model_id, - vision_model_id=vision_model_id, + agent_llm_id=agent_llm_id, + image_generation_config_id=image_generation_config_id, + vision_llm_config_id=vision_llm_config_id, ) if not result["allowed"]: raise AutomationModelPolicyError(result["violations"]) diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 704c9cf9b..5eebffd65 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -115,12 +115,14 @@ def init_worker(**kwargs): initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, + initialize_vision_llm_router, ) initialize_openrouter_integration() initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() + initialize_vision_llm_router() # Celery configuration, sourced from the central Config singleton @@ -190,8 +192,6 @@ celery_app = Celery( "app.tasks.celery_tasks.stripe_reconciliation_task", "app.tasks.celery_tasks.auto_reload_task", "app.tasks.celery_tasks.gateway_tasks", - "app.etl_pipeline.cache.eviction.task", - "app.indexing_pipeline.cache.eviction.task", "app.automations.tasks.execute_run", "app.automations.triggers.builtin.schedule.selector", "app.automations.triggers.builtin.event.selector", @@ -306,18 +306,6 @@ celery_app.conf.beat_schedule = { "schedule": crontab(hour="3", minute="17"), "options": {"expires": 600}, }, - # Prune the ETL parse cache (TTL + size budget) once daily, off-peak. - "evict-etl-cache": { - "task": "evict_etl_cache", - "schedule": crontab(hour="4", minute="0"), - "options": {"expires": 600}, - }, - # Prune the embedding cache (chunk+embedding sets) once daily, off-peak. 
- "evict-embedding-cache": { - "task": "evict_embedding_cache", - "schedule": crontab(hour="4", minute="30"), - "options": {"expires": 600}, - }, # Fire due automation schedule triggers (Beat entry owned by the schedule # trigger; see app.automations.triggers.builtin.schedule.source). **SCHEDULE_BEAT_SCHEDULE, diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 63be54654..bbaf3ac55 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -78,7 +78,8 @@ def load_global_llm_configs(): # stamps) never leak into the cached YAML structure. configs = copy.deepcopy(data.get("global_llm_configs", [])) - # Lazy import keeps the `app.config` -> `app.services` edge one-way. + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. from app.services.provider_capabilities import derive_supports_image_input seen_slugs: dict[str, int] = {} @@ -103,7 +104,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -119,10 +120,10 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) - # Stamp Auto ranking metadata. YAML configs are always + # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose provider == "openrouter" via _enrich_health. + # whose provider == "OPENROUTER" via _enrich_health. 
try: from app.services.quality_score import static_score_yaml @@ -132,7 +133,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose provider is openrouter are also subject + # YAML cfgs whose provider is OPENROUTER are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. @@ -210,6 +211,42 @@ def load_global_image_gen_configs(): return [] +def load_global_vision_llm_configs(): + data = _global_config_data() + if not data: + return [] + + try: + configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or []) + for cfg in configs: + if isinstance(cfg, dict): + cfg.setdefault("billing_tier", "free") + return configs + except Exception as e: + print(f"Warning: Failed to load global vision LLM configs: {e}") + return [] + + +def load_vision_llm_router_settings(): + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + } + + data = _global_config_data() + if not data: + return default_settings + + try: + settings = data.get("vision_llm_router_settings", {}) + return {**default_settings, **settings} + except Exception as e: + print(f"Warning: Failed to load vision LLM router settings: {e}") + return default_settings + + def load_image_gen_router_settings(): """ Load router settings for image generation Auto mode from YAML file. @@ -326,8 +363,8 @@ def initialize_openrouter_integration(): else: print("Info: OpenRouter integration enabled but no models fetched") - # Image generation emissions reuse the catalogue already cached by - # ``service.initialize`` + # Image generation + vision LLM emissions are opt-in (issue L). + # Both reuse the catalogue already cached by ``service.initialize`` # so we don't make additional network calls here. 
if settings.get("image_generation_enabled"): try: @@ -341,26 +378,21 @@ def initialize_openrouter_integration(): except Exception as e: print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") - refresh_global_model_catalog() + if settings.get("vision_enabled"): + try: + vision_configs = service.get_vision_llm_configs() + if vision_configs: + config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) + print( + f"Info: OpenRouter integration added {len(vision_configs)} " + f"vision LLM models" + ) + except Exception as e: + print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") -def materialize_global_configs(): - from app.services.global_model_catalog import materialize_global_model_catalog - - return materialize_global_model_catalog( - chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), - image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), - ) - - -def refresh_global_model_catalog(): - connections, models = materialize_global_configs() - config.GLOBAL_CONNECTIONS = connections - config.GLOBAL_MODELS = models - - def initialize_pricing_registration(): """ Teach LiteLLM the per-token cost of every deployment in @@ -398,10 +430,7 @@ def initialize_llm_router(): router_settings = config.ROUTER_SETTINGS if not all_configs: - print( - "Info: No global LLM configs found; global Auto pool is unavailable. " - "Auto can still use enabled BYOK models." - ) + print("Info: No global LLM configs found, Auto mode will not be available") return try: @@ -446,6 +475,32 @@ def initialize_image_gen_router(): print(f"Warning: Failed to initialize Image Generation Router: {e}") +def initialize_vision_llm_router(): + vision_configs = load_global_vision_llm_configs() + # Reuse the router settings already parsed at Config construction. 
The + # *configs* list is intentionally re-read from YAML (it must exclude the + # OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS). + router_settings = config.VISION_LLM_ROUTER_SETTINGS + + if not vision_configs: + print( + "Info: No global vision LLM configs found, " + "Vision LLM Auto mode will not be available" + ) + return + + try: + from app.services.vision_llm_router_service import VisionLLMRouterService + + VisionLLMRouterService.initialize(vision_configs, router_settings) + print( + f"Info: Vision LLM Router initialized with {len(vision_configs)} models " + f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" + ) + except Exception as e: + print(f"Warning: Failed to initialize Vision LLM Router: {e}") + + class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -486,28 +541,6 @@ class Config: # Database DATABASE_URL = os.getenv("DATABASE_URL") - # When TRUE (default) the app ensures extensions/tables/indexes exist on - # startup. Set FALSE in environments where schema is owned exclusively by - # Alembic migrations to skip all boot-time DDL. - DB_BOOTSTRAP_ON_STARTUP = ( - os.getenv("DB_BOOTSTRAP_ON_STARTUP", "TRUE").upper() == "TRUE" - ) - # Per-session lock_timeout (ms) applied to boot-time DDL so a contended - # CREATE INDEX / CREATE TABLE fails fast instead of hanging the FastAPI - # lifespan forever behind another transaction's lock. - DB_DDL_LOCK_TIMEOUT_MS = int(os.getenv("DB_DDL_LOCK_TIMEOUT_MS", "5000")) - # Global idle_in_transaction_session_timeout (ms) applied to every pooled - # connection so an abandoned "idle in transaction" session can't wedge the - # database indefinitely. 0 disables. Only applied to asyncpg connections. - DB_IDLE_IN_TX_TIMEOUT_MS = int(os.getenv("DB_IDLE_IN_TX_TIMEOUT_MS", "900000")) - # Same protection for the separate Celery worker engine, where long-running - # ingestion/podcast/video tasks live. 
Kept higher than the web default so a - # legitimate per-document embed window is never reaped: if a task hasn't - # touched the DB in 60 min it's treated as orphaned and dropped. 0 disables. - DB_CELERY_IDLE_IN_TX_TIMEOUT_MS = int( - os.getenv("DB_CELERY_IDLE_IN_TX_TIMEOUT_MS", "3600000") - ) - # Celery / Redis # Redis (single endpoint for Celery broker, result backend, and app cache). # Legacy CELERY_BROKER_URL / CELERY_RESULT_BACKEND / REDIS_APP_URL still @@ -557,15 +590,14 @@ class Config: # Platform web search (SearXNG) SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST") - SURFSENSE_PUBLIC_URL = os.getenv("SURFSENSE_PUBLIC_URL") - NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") or SURFSENSE_PUBLIC_URL + NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") # Backend URL to override the http to https in the OAuth redirect URI - BACKEND_URL = os.getenv("BACKEND_URL") or SURFSENSE_PUBLIC_URL + BACKEND_URL = os.getenv("BACKEND_URL") - # Messaging gateway + # Messaging gateway (Telegram v1) # Global master switch: when FALSE, no gateway supervisors/workers start and all - # gated gateway HTTP routes return 404, regardless of the per-channel flags below. - GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "FALSE").upper() == "TRUE" + # gateway HTTP routes return 404, regardless of the per-channel flags below. + GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "TRUE").upper() == "TRUE" TELEGRAM_SHARED_BOT_TOKEN = os.getenv("TELEGRAM_SHARED_BOT_TOKEN") TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME") TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET") @@ -730,7 +762,7 @@ class Config: os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") ) - # Per-podcast reservation (in micro-USD). One chat model call generating + # Per-podcast reservation (in micro-USD). One agent LLM call generating # a transcript, typically 5k-20k completion tokens. $0.20 covers a long # premium-model run. Tune via env. 
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( @@ -836,13 +868,6 @@ class Config: # LLM instances are now managed per-user through the LLMConfig system # Legacy environment variables removed in favor of user-specific configurations - # True when an operator-provided global_llm_config.yaml is present. - # Used to gate the per-search-space LLM onboarding flow: when a global - # config file exists, search spaces inherit it and onboarding is skipped. - GLOBAL_LLM_CONFIG_FILE_EXISTS = ( - BASE_DIR / "app" / "config" / "global_llm_config.yaml" - ).exists() - # Global LLM Configurations (optional) # Load from global_llm_config.yaml if available # These can be used as default options for users @@ -857,17 +882,11 @@ class Config: # Router settings for Image Generation Auto mode IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() - # Virtual GLOBAL connection/model catalog. This is server-only metadata - # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. - from app.services.global_model_catalog import ( - materialize_global_model_catalog as _materialize_global_model_catalog, - ) + # Global Vision LLM Configurations (optional) + GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs() - GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( - chat_configs=GLOBAL_LLM_CONFIGS, - image_configs=GLOBAL_IMAGE_GEN_CONFIGS, - ) - del _materialize_global_model_catalog + # Router settings for Vision LLM Auto mode + VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() # OpenRouter Integration settings (optional) OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings() @@ -933,47 +952,6 @@ class Config: AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT") AZURE_DI_KEY = os.getenv("AZURE_DI_KEY") - # ETL parse cache: reuse parser output for identical bytes across workspaces. 
- ETL_CACHE_ENABLED = ( - os.getenv("ETL_CACHE_ENABLED", "false").strip().lower() == "true" - ) - # Bump to invalidate every cached entry after a parser/behaviour change. - ETL_CACHE_PARSER_VERSION = int(os.getenv("ETL_CACHE_PARSER_VERSION", "1")) - ETL_CACHE_TTL_DAYS = int(os.getenv("ETL_CACHE_TTL_DAYS", "90")) - ETL_CACHE_MAX_TOTAL_MB = int(os.getenv("ETL_CACHE_MAX_TOTAL_MB", "5120")) - ETL_CACHE_EVICTION_BATCH = int(os.getenv("ETL_CACHE_EVICTION_BATCH", "500")) - # Optional dedicated blob storage; unset reuses the main file_storage backend. - ETL_CACHE_STORAGE_BACKEND = os.getenv("ETL_CACHE_STORAGE_BACKEND") - ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER") - ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH") - - # Embedding cache: reuse chunk+embedding output for identical markdown across - # workspaces. Blobs share the ETL_CACHE_STORAGE_* backend. - EMBEDDING_CACHE_ENABLED = ( - os.getenv("EMBEDDING_CACHE_ENABLED", "false").strip().lower() == "true" - ) - # Bump to invalidate every cached embedding set after a chunker change. - EMBEDDING_CACHE_CHUNKER_VERSION = int( - os.getenv("EMBEDDING_CACHE_CHUNKER_VERSION", "1") - ) - EMBEDDING_CACHE_TTL_DAYS = int(os.getenv("EMBEDDING_CACHE_TTL_DAYS", "90")) - EMBEDDING_CACHE_MAX_TOTAL_MB = int( - os.getenv("EMBEDDING_CACHE_MAX_TOTAL_MB", "5120") - ) - EMBEDDING_CACHE_EVICTION_BATCH = int( - os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500") - ) - - # Incremental re-indexing: on document edits, keep chunk rows whose text is - # unchanged (reusing their embeddings) and embed only new/changed chunks. - # Kill switch -- disabling falls back to delete-all + full re-embed. - CHUNK_RECONCILE_ENABLED = ( - os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true" - ) - INDEXING_CHUNK_INSERT_BATCH_SIZE = int( - os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200") - ) - # Proxy provider selection. 
Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 329d96e37..1c09a91ac 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -1,236 +1,362 @@ # Global LLM Configuration # # SETUP INSTRUCTIONS: -# 1. Copy this file to global_llm_config.yaml. -# 2. Replace placeholder credentials, endpoints, deployment names, and pricing -# with values from your own provider accounts. +# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys +# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist # -# This file is intentionally safe to commit. Do not put real API keys in this -# example file. +# NOTE: The example API keys below are placeholders and won't work. +# Replace them with your actual API keys to enable global configurations. # -# These YAML entries are materialized at startup as server-owned GLOBAL -# connections and models: +# These configurations will be available to all users as a convenient option +# Users can choose to use these global configs or add their own # -# global_llm_configs -> GLOBAL chat models -# global_image_generation_configs -> GLOBAL image generation models +# AUTO MODE (Recommended): +# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs +# - This helps avoid rate limits by distributing requests across multiple providers +# - New users are automatically assigned Auto mode by default +# - Configure router_settings below to customize the load balancing behavior # -# Do not add global_connections or global_models sections here. 
They are -# runtime-derived metadata exposed through the model-connections APIs. -# -# Static config shape: -# - Connection fields: provider, api_key, api_base, api_version -# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params -# - Public no-login SEO metadata: seo_title, seo_description -# - Prompt defaults: system_instructions, use_default_system_instructions, -# citations_enabled -# -# Provider notes: -# - Use the canonical provider field. -# - For Azure, use the bare deployment name in model_name, for example -# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from -# provider: "azure". -# -# GLOBAL ID namespace: -# - ID 0 is reserved for Auto mode. -# - Negative IDs are server-owned GLOBAL models. -# - Positive IDs are user/BYOK database models. -# - Keep static IDs unique across chat and image generation. -# - Suggested static ranges: chat -1..-999, image -2001..-2999. -# - Vision is not a separate config/table. Chat models that accept images use -# supports_image_input: true. +# Structure matches NewLLMConfig: +# - Model configuration (provider, model_name, api_key, etc.) +# - Prompt configuration (system_instructions, citations_enabled) # # COST-BASED PREMIUM CREDITS: -# Each premium model bills the user's USD-credit balance based on provider cost -# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does -# not know, declare per-token costs inline: +# Each premium config bills the user's USD-credit balance based on the +# actual provider cost reported by LiteLLM. For models LiteLLM already +# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything. +# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment) +# or any model LiteLLM doesn't have in its built-in pricing table, declare +# per-token costs inline so they bill correctly: # # litellm_params: -# base_model: "my-custom-deployment" -# # USD per token; 0.00000125 == $1.25 per million input tokens. 
-# input_cost_per_token: 0.00000125 -# output_cost_per_token: 0.00001 +# base_model: "my-custom-azure-deploy" +# # USD per token; e.g. 0.000003 == $3.00 per million input tokens +# input_cost_per_token: 0.000003 +# output_cost_per_token: 0.000015 # -# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's -# API. Models without resolvable pricing debit $0 and log a warning. +# OpenRouter dynamic models pull pricing automatically from OpenRouter's +# API — no inline declaration needed. Models without resolvable pricing +# debit $0 from the user's balance and log a WARNING. -# ============================================================================= -# Chat Auto Mode Router Settings -# ============================================================================= -# These settings control how the LiteLLM Router distributes Auto-mode requests -# across curated router-eligible GLOBAL chat deployments. +# Router Settings for Auto Mode +# These settings control how the LiteLLM Router distributes requests across models router_settings: # Routing strategy options: - # - "usage-based-routing": Routes to deployment with lowest current usage. - # - "simple-shuffle": Random distribution with optional RPM/TPM weighting. - # - "least-busy": Routes to least busy deployment. - # - "latency-based-routing": Routes based on response latency. 
+ # - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits) + # - "simple-shuffle": Random distribution with optional RPM/TPM weighting + # - "least-busy": Routes to least busy deployment + # - "latency-based-routing": Routes based on response latency routing_strategy: "usage-based-routing" + + # Number of retries before failing num_retries: 3 + + # Number of failures allowed before cooling down a deployment allowed_fails: 3 + + # Cooldown time in seconds after allowed_fails is exceeded cooldown_time: 60 - # Optional fallback map: - # fallbacks: - # - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]} -# ============================================================================= -# Static GLOBAL Chat Models -# ============================================================================= + # Fallback models (optional) - when primary fails, try these + # Format: [{"primary_model": ["fallback1", "fallback2"]}] + # fallbacks: [] + global_llm_configs: - # Premium Azure chat model with image input support and explicit custom - # pricing. This is the current shape to use for hosted GPT 5.x deployments. + # Example: OpenAI GPT-4 Turbo with citations enabled - id: -1 - name: "Azure GPT 5.1" - billing_tier: "premium" - anonymous_enabled: false - seo_enabled: false - seo_slug: "azure-gpt-5-1" - quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.1" - supports_image_input: true - supports_tools: true - max_input_tokens: 400000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - # api_version is optional. Include it if your Azure deployment requires a - # specific API version. 
- # api_version: "2025-04-01-preview" - rpm: 47500 - tpm: 14750000 - litellm_params: - max_tokens: 16384 - base_model: "gpt-5.1" - input_cost_per_token: 0.00000125 - output_cost_per_token: 0.00001 - system_instructions: "" - use_default_system_instructions: true - citations_enabled: true - - # Larger premium chat model. If your provider prices long-context traffic - # differently, choose a conservative flat price or document the limitation - # next to the inline pricing. - - id: -2 - name: "Azure GPT 5.4" - billing_tier: "premium" - anonymous_enabled: false - seo_enabled: false - seo_slug: "azure-gpt-5-4" - quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.4" - supports_image_input: true - supports_tools: true - max_input_tokens: 400000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 150000 - tpm: 15000000 - litellm_params: - max_tokens: 16384 - base_model: "gpt-5.4" - input_cost_per_token: 0.0000025 - output_cost_per_token: 0.000015 - system_instructions: "" - use_default_system_instructions: true - citations_enabled: true - - # Free/no-login hosted model. Free models are visible to users when - # anonymous_enabled/seo_enabled are true but do not debit premium credits. - - id: -3 - name: "Azure GPT 5.4 Mini" + name: "Global GPT-4 Turbo" + description: "OpenAI's GPT-4 Turbo with default prompts and citations" billing_tier: "free" anonymous_enabled: true seo_enabled: true - seo_slug: "gpt-5-4-mini-no-login" - seo_title: "Free GPT 5.4 Mini Chat" - seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in." 
+ seo_slug: "gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "azure" - model_name: "gpt-5.4-mini" - supports_image_input: false - supports_tools: true - max_input_tokens: 128000 - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 15000 - tpm: 15000000 + provider: "OPENAI" + model_name: "gpt-4-turbo-preview" + api_key: "sk-your-openai-api-key-here" + api_base: "" + # Rate limits for load balancing (requests/tokens per minute) + rpm: 500 # Requests per minute + tpm: 100000 # Tokens per minute litellm_params: - max_tokens: 16384 - base_model: "gpt-5.4-mini" + temperature: 0.7 + max_tokens: 4000 + # Prompt Configuration + system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS + use_default_system_instructions: true + citations_enabled: true + + # Example: Anthropic Claude 3 Opus + - id: -2 + name: "Global Claude 3 Opus" + description: "Anthropic's most capable model with citations" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "claude-3-opus" + quota_reserve_tokens: 4000 + provider: "ANTHROPIC" + model_name: "claude-3-opus-20240229" + api_key: "sk-ant-your-anthropic-api-key-here" + api_base: "" + rpm: 1000 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 system_instructions: "" use_default_system_instructions: true citations_enabled: true - # Planner LLM. This is operator-only and is not shown in the user-facing - # model selector. Only one global_llm_configs entry should set is_planner. 
+ # Example: Fast model - GPT-3.5 Turbo (citations disabled for speed) + - id: -3 + name: "Global GPT-3.5 Turbo (Fast)" + description: "Fast responses without citations for quick queries" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "gpt-3.5-turbo-fast" + quota_reserve_tokens: 2000 + provider: "OPENAI" + model_name: "gpt-3.5-turbo" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 3500 # GPT-3.5 has higher rate limits + tpm: 200000 + litellm_params: + temperature: 0.5 + max_tokens: 2000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: false # Disabled for faster responses + + # Example: Chinese LLM - DeepSeek with custom instructions + - id: -4 + name: "Global DeepSeek Chat (Chinese)" + description: "DeepSeek optimized for Chinese language responses" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "deepseek-chat-chinese" + quota_reserve_tokens: 4000 + provider: "DEEPSEEK" + model_name: "deepseek-chat" + api_key: "your-deepseek-api-key-here" + api_base: "https://api.deepseek.com/v1" + rpm: 60 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + # Custom system instructions for Chinese responses + system_instructions: | + + You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. + + Today's date (UTC): {resolved_today} + + IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language. 
+ + use_default_system_instructions: false + citations_enabled: true + + # Example: Azure OpenAI GPT-4o + # IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params + # to enable accurate token counting, cost tracking, and max token limits + - id: -5 + name: "Global Azure GPT-4o" + description: "Azure OpenAI GPT-4o deployment" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "azure-gpt-4o" + quota_reserve_tokens: 4000 + provider: "AZURE" + # model_name format for Azure: azure/ + model_name: "azure/gpt-4o-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" # Azure API version + rpm: 1000 + tpm: 150000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + # REQUIRED for Azure: Specify the underlying OpenAI model + # This fixes "Could not identify azure model" warnings + # Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo + base_model: "gpt-4o" + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Azure OpenAI GPT-4 Turbo + - id: -6 + name: "Global Azure GPT-4 Turbo" + description: "Azure OpenAI GPT-4 Turbo deployment" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "azure-gpt-4-turbo" + quota_reserve_tokens: 4000 + provider: "AZURE" + model_name: "azure/gpt-4-turbo-deployment" + api_key: "your-azure-api-key-here" + api_base: "https://your-resource.openai.azure.com" + api_version: "2024-02-15-preview" + rpm: 500 + tpm: 100000 + litellm_params: + temperature: 0.7 + max_tokens: 4000 + base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Groq - Fast inference + - id: -7 + name: "Global Groq Llama 3" + description: "Ultra-fast Llama 3 70B via Groq" + billing_tier: "free" + anonymous_enabled: true + 
seo_enabled: true + seo_slug: "groq-llama-3" + quota_reserve_tokens: 8000 + provider: "GROQ" + model_name: "llama3-70b-8192" + api_key: "your-groq-api-key-here" + api_base: "" + rpm: 30 # Groq has lower rate limits on free tier + tpm: 14400 + litellm_params: + temperature: 0.7 + max_tokens: 8000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: MiniMax M3 - High-performance with 512K context window + - id: -8 + name: "Global MiniMax M3" + description: "MiniMax M3 with 512K context window and competitive pricing" + billing_tier: "free" + anonymous_enabled: true + seo_enabled: true + seo_slug: "minimax-m3" + quota_reserve_tokens: 4000 + provider: "MINIMAX" + model_name: "MiniMax-M3" + api_key: "your-minimax-api-key-here" + api_base: "https://api.minimax.io/v1" + rpm: 60 + tpm: 100000 + litellm_params: + temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0 + max_tokens: 4000 + system_instructions: "" + use_default_system_instructions: true + citations_enabled: true + + # Example: Planner LLM - small, fast model used for internal utility tasks + # + # The PLANNER role handles short, structured internal calls (KB query + # rewriting, date extraction, recency classification, etc.) that don't + # need frontier-tier capability. Pointing the planner at a cheap+fast + # model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...) + # typically saves 500ms-1.5s per turn vs. routing those same internal + # calls through the user's chat model. + # + # Activation: + # - Mark EXACTLY ONE global config with ``is_planner: true``. + # - If multiple are marked, the first one wins and a WARNING is logged. + # - If none is marked, every internal call falls back to the user's + # chat LLM (same behavior as before this flag existed). 
+ # + # This config is operator-only — it is NOT exposed in the user-facing + # model selector, never billed against premium quota, and the + # billing_tier / anonymous_enabled fields below are ignored. - id: -9 - name: "Azure GPT 5.x Nano Planner" + name: "Global Planner (GPT-4o mini)" + description: "Internal-only planner LLM for query rewriting and classification" is_planner: true billing_tier: "free" anonymous_enabled: false seo_enabled: false quota_reserve_tokens: 1000 - provider: "azure" - model_name: "gpt-5.4-nano" - supports_image_input: false - supports_tools: false - router_pool_eligible: false - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - rpm: 20000 - tpm: 4000000 + provider: "OPENAI" + model_name: "gpt-4o-mini" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 3500 + tpm: 200000 litellm_params: temperature: 0 max_tokens: 1000 - base_model: "gpt-5.4-nano" system_instructions: "" use_default_system_instructions: true citations_enabled: false # ============================================================================= -# OpenRouter Dynamic Model Integration +# OpenRouter Integration # ============================================================================= -# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects -# supported models as GLOBAL chat and optionally image-generation models. -# Tier is derived per model from OpenRouter data: -# - model id ends with ":free" -> billing_tier=free -# - prompt and completion pricing are zero -> billing_tier=free -# - otherwise -> billing_tier=premium -# -# Do not use deprecated openrouter_integration.billing_tier or -# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous -# switches below. +# When enabled, dynamically fetches ALL available models from the OpenRouter API +# and injects them as global configs. This gives premium users access to any model +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) 
via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. +# Models are fetched at startup and refreshed periodically in the background. +# All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" + # Tier is derived PER MODEL from OpenRouter's own API signals: + # - id ends with ":free" -> billing_tier=free + # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free + # - otherwise -> billing_tier=premium + # No global billing_tier knob is honored; any legacy value emits a startup warning. + + # Anonymous access is split by tier so operators can expose only free + # models to no-login users without leaking paid inference. anonymous_enabled_paid: false anonymous_enabled_free: false + seo_enabled: false + # quota_reserve_tokens: tokens reserved per call for quota enforcement quota_reserve_tokens: 4000 - - # Base negative ID namespace for dynamic chat models. IDs are derived - # deterministically so they survive catalog churn. Do not overlap static IDs. + # id_offset: base negative ID for dynamically generated configs. + # Model IDs are derived deterministically via BLAKE2b so they survive + # catalogue churn. Must not overlap with your static global_llm_configs IDs. id_offset: -10000 - - # Separate base negative ID namespace for dynamic image-generation models. - image_id_offset: -20000 - - # How often to refresh the OpenRouter catalog. 0 means startup only. + # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 - # Paid OpenRouter models may join curated router pools when eligible. + # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router + # for per-deployment accounting when OR premium models participate in the + # shared sub-agent "auto" pool. 
They do NOT cap OpenRouter itself — your + # real account limits live at https://openrouter.ai/settings/limits. rpm: 200 tpm: 1000000 - # Free OpenRouter models are available for user-facing selection/pinning but - # should be treated as a shared-account bucket, not normal router capacity. + # Rate limits for FREE OpenRouter models. Informational only: free OR + # models are intentionally kept OUT of the LiteLLM Router pool, because + # OpenRouter enforces free-tier limits globally per account (~20 RPM + + # 50-1000 daily requests across every ":free" model combined) — + # per-deployment router accounting can't represent a shared bucket + # correctly. Free OR models stay fully available in the model selector + # and for user-facing Auto thread pinning. free_rpm: 20 free_tpm: 100000 - # Image generation is opt-in to avoid injecting a large image catalog during - # upgrades. Vision-capable chat models are represented with - # supports_image_input: true. + # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue + # contains hundreds of image- and vision-capable models; turning these on + # injects them into the global Image-Generation / Vision-LLM model + # selectors alongside any static configs. Tier (free/premium) is derived + # per model the same way it is for chat (`:free` suffix or zero pricing). + # When a user picks a premium image/vision model the call debits the + # shared $5 USD-cost-based premium credit pool — so leaving these off + # avoids surprise quota burn on existing deployments. Default: false. 
image_generation_enabled: false vision_enabled: false @@ -241,80 +367,191 @@ openrouter_integration: citations_enabled: true # ============================================================================= -# Image Generation Auto Mode Router Settings +# Image Generation Configuration # ============================================================================= +# These configurations power the image generation feature using litellm.aimage_generation(). +# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, +# Recraft, OpenRouter, Xinference, Nscale +# +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs. + +# Router Settings for Image Generation Auto Mode image_generation_router_settings: routing_strategy: "usage-based-routing" num_retries: 3 allowed_fails: 3 cooldown_time: 60 -# ============================================================================= -# Static GLOBAL Image Generation Models -# ============================================================================= global_image_generation_configs: - - id: -2001 - name: "Azure GPT Image 1.5" - billing_tier: "premium" - provider: "azure" - model_name: "gpt-image-1.5" + # Example: OpenAI DALL-E 3 + - id: -1 + name: "Global DALL-E 3" + description: "OpenAI's DALL-E 3 for high-quality image generation" + provider: "OPENAI" + model_name: "dall-e-3" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) + litellm_params: {} + + # Example: OpenAI GPT Image 1 + - id: -2 + name: "Global GPT Image 1" + description: "OpenAI's GPT Image 1 model" + provider: "OPENAI" + model_name: "gpt-image-1" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 50 + litellm_params: {} + + # Example: Azure OpenAI DALL-E 3 + - id: -3 + name: "Global Azure DALL-E 3" + description: "Azure-hosted DALL-E 3 deployment" + provider: "AZURE_OPENAI" + model_name: "azure/dall-e-3-deployment" 
api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" - # api_version: "2025-04-01-preview" - rpm: 60 + api_version: "2024-02-15-preview" + rpm: 50 litellm_params: - base_model: "gpt-image-1.5" + base_model: "dall-e-3" - - id: -2002 - name: "Azure GPT Image 1 Mini" - billing_tier: "free" - provider: "azure" - model_name: "gpt-image-1-mini" - api_key: "your-azure-api-key-here" - api_base: "https://your-resource.openai.azure.com" - # api_version: "2025-04-01-preview" - rpm: 120 - litellm_params: - base_model: "gpt-image-1-mini" + # Example: OpenRouter Gemini Image Generation + # - id: -4 + # name: "Global Gemini Image Gen" + # description: "Google Gemini image generation via OpenRouter" + # provider: "OPENROUTER" + # model_name: "google/gemini-2.5-flash-image" + # api_key: "your-openrouter-api-key-here" + # api_base: "" + # rpm: 30 + # litellm_params: {} # ============================================================================= -# Field Notes +# Vision LLM Configuration # ============================================================================= -# Common chat/image fields: -# - provider: Canonical provider adapter name. Example: azure, openai, -# anthropic, openrouter, groq, bedrock. -# - model_name: Provider model or deployment id. For Azure, use the bare -# deployment name. The resolver prefixes LiteLLM model strings from provider. -# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the -# resolver adds /v1 when needed. -# - api_version: Optional provider-specific API version, stored on the -# materialized connection extra metadata. -# - litellm_params: Passed to LiteLLM when invoking the model. Also used for -# base_model and inline pricing registration. +# These configurations power the vision autocomplete feature (screenshot analysis). +# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3). 
+# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock, +# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom # -# Chat model fields: -# - supports_image_input: true when the chat model can consume image inputs. -# - supports_tools: true when the model can use tools/function calling. -# - max_input_tokens: Optional UI/catalog metadata for context size. -# - router_pool_eligible: false keeps a model out of shared router pools while -# still allowing direct selection/pinning. -# - is_planner: true marks the internal-only planner model. Only one config -# should set this flag. +# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs. + +# Router Settings for Vision LLM Auto Mode +vision_llm_router_settings: + routing_strategy: "usage-based-routing" + num_retries: 3 + allowed_fails: 3 + cooldown_time: 60 + +global_vision_llm_configs: + # Example: OpenAI GPT-4o (recommended for vision) + - id: -1 + name: "Global GPT-4o Vision" + description: "OpenAI's GPT-4o with strong vision capabilities" + provider: "OPENAI" + model_name: "gpt-4o" + api_key: "sk-your-openai-api-key-here" + api_base: "" + rpm: 500 + tpm: 100000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # Example: Google Gemini 2.0 Flash + - id: -2 + name: "Global Gemini 2.0 Flash" + description: "Google's fast vision model with large context" + provider: "GOOGLE" + model_name: "gemini-2.0-flash" + api_key: "your-google-ai-api-key-here" + api_base: "" + rpm: 1000 + tpm: 200000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # Example: Anthropic Claude 3.5 Sonnet + - id: -3 + name: "Global Claude 3.5 Sonnet Vision" + description: "Anthropic's Claude 3.5 Sonnet with vision support" + provider: "ANTHROPIC" + model_name: "claude-3-5-sonnet-20241022" + api_key: "sk-ant-your-anthropic-api-key-here" + api_base: "" + rpm: 1000 + tpm: 100000 + litellm_params: + temperature: 0.3 + max_tokens: 1000 + + # 
Example: Azure OpenAI GPT-4o + # - id: -4 + # name: "Global Azure GPT-4o Vision" + # description: "Azure-hosted GPT-4o for vision analysis" + # provider: "AZURE_OPENAI" + # model_name: "azure/gpt-4o-deployment" + # api_key: "your-azure-api-key-here" + # api_base: "https://your-resource.openai.azure.com" + # api_version: "2024-02-15-preview" + # rpm: 500 + # tpm: 100000 + # litellm_params: + # temperature: 0.3 + # max_tokens: 1000 + # base_model: "gpt-4o" + +# Notes: +# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing +# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) +# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.) +# - The 'api_key' field will not be exposed to users via API +# - system_instructions: Custom prompt or empty string to use defaults +# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty +# - citations_enabled: true = include citation instructions, false = include anti-citation instructions +# - All standard LiteLLM providers are supported +# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) +# These help the router distribute load evenly and avoid rate limit errors # -# Catalog and access fields: -# - billing_tier: "free" or "premium". -# - anonymous_enabled: Whether the model appears in the public no-login catalog. -# - seo_enabled: Whether a /free/ landing page is generated. -# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once -# public. -# - seo_title / seo_description: Optional SEO metadata overrides. -# - quota_reserve_tokens: Tokens reserved before each chat LLM call. -# - rpm / tpm: Optional rate limits for router accounting and load balancing. # -# Image generation notes: -# - Image-generation configs use the same GLOBAL ID namespace as chat models. -# - Only RPM is relevant for most image-generation APIs. 
-# - The runtime uses litellm.aimage_generation(). -# - Image billing currently uses billing_tier and model catalog metadata. Keep -# quota reserve tuning in code/catalog unless the materializer copies a YAML -# key for image quota reservation. +# IMAGE GENERATION NOTES: +# - Image generation configs use the same ID scheme as LLM configs (negative for global) +# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure), +# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter) +# - The router uses litellm.aimage_generation() for async image generation +# - Only RPM (requests per minute) is relevant for image generation rate limiting. +# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token. +# +# VISION LLM NOTES: +# - Vision configs use the same ID scheme (negative for global, positive for user DB) +# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.) +# - Lower temperature (0.3) is recommended for accurate screenshot analysis +# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions +# +# PLANNER LLM NOTES: +# - is_planner: true marks a config as the internal-only planner LLM (small, +# fast model used for KB query rewriting, date extraction, recency +# classification, etc.). Only one config may carry this flag — if +# multiple do, the first one wins and a startup WARNING is logged. +# - When no config is marked is_planner, every internal utility call falls +# back to the user's chat LLM (the historical behavior). +# - Planner configs are NOT shown in the user-facing model selector and +# are NOT billed against the user's premium quota. Their billing_tier, +# anonymous_enabled, seo_* fields are ignored. +# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash, +# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k +# prompt. Frontier models here defeat the purpose of the flag. 
+# +# TOKEN QUOTA & ANONYMOUS ACCESS NOTES: +# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota. +# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog. +# - seo_enabled: true/false. Whether a /free/ landing page is generated. +# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public. +# - seo_title: Optional HTML title tag override for the model's /free/ page. +# - seo_description: Optional meta description override for the model's /free/ page. +# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement. +# Independent of litellm_params.max_tokens. Used by the token quota service. diff --git a/surfsense_backend/app/connectors/dropbox/content_extractor.py b/surfsense_backend/app/connectors/dropbox/content_extractor.py index 300010c26..372d2fc82 100644 --- a/surfsense_backend/app/connectors/dropbox/content_extractor.py +++ b/surfsense_backend/app/connectors/dropbox/content_extractor.py @@ -90,12 +90,11 @@ async def download_and_extract_content( if error: return None, metadata, error - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=temp_file_path, filename=file_name), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=temp_file_path, filename=file_name) ) markdown = result.markdown_content return markdown, metadata, None diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 1ea047978..59392831d 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -122,13 +122,12 @@ async def 
download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown via the cache-aware ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache + """Parse a local file to markdown using the unified ETL pipeline.""" from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename) ) return result.markdown_content diff --git a/surfsense_backend/app/connectors/onedrive/content_extractor.py b/surfsense_backend/app/connectors/onedrive/content_extractor.py index fb1d31fbc..3154f2eca 100644 --- a/surfsense_backend/app/connectors/onedrive/content_extractor.py +++ b/surfsense_backend/app/connectors/onedrive/content_extractor.py @@ -84,12 +84,11 @@ async def download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown via the cache-aware ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache + """Parse a local file to markdown using the unified ETL pipeline.""" from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename) ) return result.markdown_content diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 3f098d5d2..2d672131b 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1,4 +1,3 @@ -import logging import uuid 
from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -35,8 +34,6 @@ from app.config import config if config.AUTH_TYPE == "GOOGLE": from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID -logger = logging.getLogger(__name__) - DATABASE_URL = config.DATABASE_URL @@ -201,15 +198,79 @@ class DocumentStatus: return None -class ConnectionScope(StrEnum): - GLOBAL = "GLOBAL" - SEARCH_SPACE = "SEARCH_SPACE" - USER = "USER" +class LiteLLMProvider(StrEnum): + """ + Enum for LLM providers supported by LiteLLM. + """ + + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + GOOGLE = "GOOGLE" + AZURE_OPENAI = "AZURE_OPENAI" + BEDROCK = "BEDROCK" + VERTEX_AI = "VERTEX_AI" + GROQ = "GROQ" + COHERE = "COHERE" + MISTRAL = "MISTRAL" + DEEPSEEK = "DEEPSEEK" + XAI = "XAI" + OPENROUTER = "OPENROUTER" + TOGETHER_AI = "TOGETHER_AI" + FIREWORKS_AI = "FIREWORKS_AI" + REPLICATE = "REPLICATE" + PERPLEXITY = "PERPLEXITY" + OLLAMA = "OLLAMA" + ALIBABA_QWEN = "ALIBABA_QWEN" + MOONSHOT = "MOONSHOT" + ZHIPU = "ZHIPU" + ANYSCALE = "ANYSCALE" + DEEPINFRA = "DEEPINFRA" + CEREBRAS = "CEREBRAS" + SAMBANOVA = "SAMBANOVA" + AI21 = "AI21" + CLOUDFLARE = "CLOUDFLARE" + DATABRICKS = "DATABRICKS" + COMETAPI = "COMETAPI" + HUGGINGFACE = "HUGGINGFACE" + GITHUB_MODELS = "GITHUB_MODELS" + MINIMAX = "MINIMAX" + CUSTOM = "CUSTOM" -class ModelSource(StrEnum): - DISCOVERED = "DISCOVERED" - MANUAL = "MANUAL" +class ImageGenProvider(StrEnum): + """ + Enum for image generation providers supported by LiteLLM. + This is a subset of LLM providers — only those that support image generation. 
+ See: https://docs.litellm.ai/docs/image_generation#supported-providers + """ + + OPENAI = "OPENAI" + AZURE_OPENAI = "AZURE_OPENAI" + GOOGLE = "GOOGLE" # Google AI Studio + VERTEX_AI = "VERTEX_AI" + BEDROCK = "BEDROCK" # AWS Bedrock + RECRAFT = "RECRAFT" + OPENROUTER = "OPENROUTER" + XINFERENCE = "XINFERENCE" + NSCALE = "NSCALE" + + +class VisionProvider(StrEnum): + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + GOOGLE = "GOOGLE" + AZURE_OPENAI = "AZURE_OPENAI" + VERTEX_AI = "VERTEX_AI" + BEDROCK = "BEDROCK" + XAI = "XAI" + OPENROUTER = "OPENROUTER" + OLLAMA = "OLLAMA" + GROQ = "GROQ" + TOGETHER_AI = "TOGETHER_AI" + FIREWORKS_AI = "FIREWORKS_AI" + DEEPSEEK = "DEEPSEEK" + MISTRAL = "MISTRAL" + CUSTOM = "CUSTOM" class LogLevel(StrEnum): @@ -638,11 +699,11 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto model pin for this thread: concrete resolved global LLM + # Auto (Fastest) model pin for this thread: concrete resolved global LLM # config id. NULL means no pin; Auto will resolve on the next turn. # Single-writer invariant: only app.services.auto_model_pin_service sets # or clears this column (plus bulk clears when a search space's - # chat_model_id changes). Unindexed: all reads are by primary key. + # agent_llm_id changes). Unindexed: all reads are by primary key. pinned_llm_config_id = Column(Integer, nullable=True) # Surface metadata for first-party SurfSense and external chat threads. @@ -1423,10 +1484,7 @@ class Document(BaseModel, TimestampMixin): created_by = relationship("User", back_populates="documents") connector = relationship("SearchSourceConnector", back_populates="documents") chunks = relationship( - "Chunk", - back_populates="document", - cascade="all, delete-orphan", - order_by="Chunk.position", + "Chunk", back_populates="document", cascade="all, delete-orphan" ) # Original upload + future derived artifacts (redacted, filled-form). 
# Model lives in app.file_storage.persistence to keep that feature cohesive. @@ -1462,11 +1520,6 @@ class Chunk(BaseModel, TimestampMixin): content = Column(Text, nullable=False) embedding = Column(Vector(config.embedding_model_instance.dimension)) - # Explicit document order; ids don't follow it since incremental - # re-indexing keeps unchanged rows across edits. Deliberately not indexed: - # ordering reads are document-scoped (covered by ix_chunks_document_id) and - # building a position index on the large chunks table is not worth it. - position = Column(Integer, nullable=False, server_default="0") document_id = Column( Integer, @@ -1548,80 +1601,73 @@ class Report(BaseModel, TimestampMixin): thread = relationship("NewChatThread") -class Connection(BaseModel, TimestampMixin): - __tablename__ = "connections" +class ImageGenerationConfig(BaseModel, TimestampMixin): + """ + Dedicated configuration table for image generation models. - provider = Column(String(100), nullable=False, index=True) - base_url = Column(String(500), nullable=True) - api_key = Column(String, nullable=True) - extra = Column(JSONB, nullable=False, default=dict, server_default="{}") - scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True) - enabled = Column(Boolean, nullable=False, default=True, server_default="true") + Separate from NewLLMConfig because image generation models don't need + system_instructions, citations_enabled, or use_default_system_instructions. + They only need provider credentials and model parameters. 
+ """ + + __tablename__ = "image_generation_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) + provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + # Credentials + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) # Azure-specific + + # Additional litellm parameters + litellm_params = Column(JSON, nullable=True, default={}) + + # Relationships + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship( + "SearchSpace", back_populates="image_generation_configs" + ) + + # User who created this config + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + user = relationship("User", back_populates="image_generation_configs") + + +class VisionLLMConfig(BaseModel, TimestampMixin): + __tablename__ = "vision_llm_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False) + custom_provider = Column(String(100), nullable=True) + model_name = Column(String(100), nullable=False) + + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + api_version = Column(String(50), nullable=True) + + litellm_params = Column(JSON, nullable=True, default={}) search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False ) + search_space = relationship("SearchSpace", back_populates="vision_llm_configs") + user_id = Column( - 
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - - search_space = relationship("SearchSpace", back_populates="connections") - user = relationship("User", back_populates="connections") - models = relationship( - "Model", - back_populates="connection", - order_by="Model.id", - cascade="all, delete-orphan", - passive_deletes=True, - ) - - __table_args__ = ( - CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - ) - - -class Model(BaseModel, TimestampMixin): - __tablename__ = "models" - - connection_id = Column( - Integer, - ForeignKey("connections.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - model_id = Column(String(255), nullable=False) - display_name = Column(String(255), nullable=True) - source = Column( - SQLAlchemyEnum(ModelSource), - nullable=False, - default=ModelSource.DISCOVERED, - server_default=ModelSource.DISCOVERED.value, - ) - supports_chat = Column(Boolean, nullable=True) - max_input_tokens = Column(Integer, nullable=True) - supports_image_input = Column(Boolean, nullable=True) - supports_tools = Column(Boolean, nullable=True) - supports_image_generation = Column(Boolean, nullable=True) - capabilities_override = Column( - JSONB, nullable=False, default=dict, server_default="{}" - ) - enabled = Column(Boolean, nullable=False, default=True, server_default="true") - billing_tier = Column(String(50), nullable=True, index=True) - catalog = Column(JSONB, nullable=False, default=dict, server_default="{}") - - connection = relationship("Connection", back_populates="models") - - __table_args__ = ( - UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - Index("ix_models_model_id", "model_id"), + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), 
nullable=False ) + user = relationship("User", back_populates="vision_llm_configs") class ImageGeneration(BaseModel, TimestampMixin): @@ -1655,9 +1701,10 @@ class ImageGeneration(BaseModel, TimestampMixin): style = Column(String(50), nullable=True) # Model-specific style parameter response_format = Column(String(50), nullable=True) # "url" or "b64_json" - # Image generation model provenance. - # 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records. - image_gen_model_id = Column(Integer, nullable=True) + # Image generation config reference + # 0 = Auto mode (router), negative IDs = global configs from YAML, + # positive IDs = ImageGenerationConfig records in DB + image_generation_config_id = Column(Integer, nullable=True) # Response data (full litellm response as JSONB) — present on success response_data = Column(JSONB, nullable=True) @@ -1699,19 +1746,19 @@ class SearchSpace(BaseModel, TimestampMixin): shared_memory_md = Column(Text, nullable=True, server_default="") - # Connection/model role bindings. 
- # Note: ID values preserve the existing convention: - # - 0: Auto mode - # - Negative IDs: Global virtual models from global_llm_config.yaml - # - Positive IDs: User/search-space models from the models table - chat_model_id = Column( - Integer, nullable=True, default=0, server_default="0" + # Search space-level LLM preferences (shared by all members) + # Note: ID values: + # - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces + # - Negative IDs: Global configs from YAML + # - Positive IDs: Custom configs from DB (NewLLMConfig table) + agent_llm_id = Column( + Integer, nullable=True, default=0 ) # For agent/chat operations, defaults to Auto mode - image_gen_model_id = Column( - Integer, nullable=True, default=0, server_default="0" - ) # For image generation, defaults to Auto mode when eligible - vision_model_id = Column( - Integer, nullable=True, default=0, server_default="0" + image_generation_config_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode + vision_llm_config_id = Column( + Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode ai_file_sort_enabled = Column( @@ -1783,12 +1830,23 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) - connections = relationship( - "Connection", + new_llm_configs = relationship( + "NewLLMConfig", back_populates="search_space", - order_by="Connection.id", + order_by="NewLLMConfig.id", + cascade="all, delete-orphan", + ) + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="search_space", + order_by="ImageGenerationConfig.id", + cascade="all, delete-orphan", + ) + vision_llm_configs = relationship( + "VisionLLMConfig", + back_populates="search_space", + order_by="VisionLLMConfig.id", cascade="all, delete-orphan", - passive_deletes=True, ) automations = relationship( @@ -1891,6 +1949,64 @@ class 
SearchSourceConnector(BaseModel, TimestampMixin): documents = relationship("Document", back_populates="connector") +class NewLLMConfig(BaseModel, TimestampMixin): + """ + New LLM configuration table that combines model settings with prompt configuration. + + This table provides: + - LLM model configuration (provider, model_name, api_key, etc.) + - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) + - Citation toggle (enable/disable citation instructions) + + Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory). + """ + + __tablename__ = "new_llm_configs" + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + + # === LLM Model Configuration (from original LLMConfig, excluding 'language') === + # Provider from the enum + provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) + # Custom provider name when provider is CUSTOM + custom_provider = Column(String(100), nullable=True) + # Just the model name without provider prefix + model_name = Column(String(100), nullable=False) + # API Key should be encrypted before storing + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + # For any other parameters that litellm supports + litellm_params = Column(JSON, nullable=True, default={}) + + # === Prompt Configuration === + # Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) + # Users can customize this from the UI + system_instructions = Column( + Text, + nullable=False, + default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS + ) + # Whether to use the default system instructions when system_instructions is empty + use_default_system_instructions = Column(Boolean, nullable=False, default=True) + + # Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected + # When disabled, an anti-citation prompt is injected instead + 
citations_enabled = Column(Boolean, nullable=False, default=True) + + # === Relationships === + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="new_llm_configs") + + # User who created this config + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + user = relationship("User", back_populates="new_llm_configs") + + class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -2257,8 +2373,22 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) - connections = relationship( - "Connection", + # LLM configs created by this user + new_llm_configs = relationship( + "NewLLMConfig", + back_populates="user", + passive_deletes=True, + ) + + # Image generation configs created by this user + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="user", + passive_deletes=True, + ) + + vision_llm_configs = relationship( + "VisionLLMConfig", back_populates="user", passive_deletes=True, ) @@ -2389,8 +2519,22 @@ else: passive_deletes=True, ) - connections = relationship( - "Connection", + # LLM configs created by this user + new_llm_configs = relationship( + "NewLLMConfig", + back_populates="user", + passive_deletes=True, + ) + + # Image generation configs created by this user + image_generation_configs = relationship( + "ImageGenerationConfig", + back_populates="user", + passive_deletes=True, + ) + + vision_llm_configs = relationship( + "VisionLLMConfig", back_populates="user", passive_deletes=True, ) @@ -2720,39 +2864,13 @@ from app.automations.persistence import ( # noqa: E402, F401 AutomationRun, AutomationTrigger, ) -from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401 from app.file_storage.persistence import DocumentFile # noqa: E402, F401 -from app.indexing_pipeline.cache.persistence.models import ( # noqa: E402, F401 - 
CachedEmbeddingSet, -) from app.notifications.persistence import Notification # noqa: E402, F401 from app.podcasts.persistence import ( # noqa: E402, F401 Podcast, PodcastStatus, ) - -def _build_connect_args() -> dict: - """Build driver connect_args, including a protective idle-in-transaction - timeout for asyncpg connections. - - A single abandoned ``idle in transaction`` session can hold table/row locks - indefinitely and wedge writes plus boot-time DDL (the classic "FastAPI - stuck at application startup" failure). Setting - ``idle_in_transaction_session_timeout`` server-side makes Postgres reap such - sessions automatically. It never affects sessions that are actively running - statements — only ones that opened a transaction and went idle. - """ - connect_args: dict = {} - idle_ms = config.DB_IDLE_IN_TX_TIMEOUT_MS - # ``server_settings`` is asyncpg-specific; only apply it for that driver. - if idle_ms and idle_ms > 0 and DATABASE_URL and "asyncpg" in DATABASE_URL: - connect_args["server_settings"] = { - "idle_in_transaction_session_timeout": str(idle_ms) - } - return connect_args - - engine = create_async_engine( DATABASE_URL, pool_size=30, @@ -2760,7 +2878,6 @@ engine = create_async_engine( pool_recycle=1800, pool_pre_ping=True, pool_timeout=30, - connect_args=_build_connect_args(), ) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) @@ -2785,117 +2902,54 @@ async def shielded_async_session(): await session.close() -# (index_name, table, CREATE statement). Built with CONCURRENTLY so an index -# build only takes a non-blocking ShareUpdateExclusiveLock — ingestion -# INSERT/UPDATE on documents/chunks keep flowing while the index builds, and a -# slow build can never freeze the FastAPI lifespan or block writers. 
-_INDEX_DEFINITIONS: list[tuple[str, str, str]] = [ - ( - "document_vector_index", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)", - ), - ( - "document_search_index", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))", - ), - ( - "chucks_vector_index", - "chunks", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)", - ), - ( - "chucks_search_index", - "chunks", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))", - ), - # pg_trgm index for efficient ILIKE '%term%' searches on titles — critical - # for the document mention picker (@mentions) to scale. - ( - "idx_documents_title_trgm", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)", - ), - ( - "idx_documents_search_space_id", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)", - ), - # Covering index for "recent documents" query — enables index-only scan. - ( - "idx_documents_search_space_updated", - "documents", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)", - ), -] - - -async def _drop_invalid_index(conn, name: str) -> None: - """Drop a leftover *invalid* index so it can be rebuilt. - - A ``CREATE INDEX CONCURRENTLY`` that is interrupted (timeout, crash, - cancellation) leaves behind an ``indisvalid = false`` index. Because the - name now exists, a later ``CREATE INDEX CONCURRENTLY IF NOT EXISTS`` would - skip it and the broken index would persist forever. Detect and drop it - first. 
- """ - result = await conn.execute( - text("SELECT indisvalid FROM pg_index WHERE indexrelid = to_regclass(:n)"), - {"n": name}, - ) - row = result.first() - if row is not None and row[0] is False: - logger.warning( - "[startup] dropping invalid leftover index %s before rebuild", name +async def setup_indexes(): + async with engine.begin() as conn: + # Create indexes + # Document embedding indexes + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))" + ) + ) + # Document Chuck Indexes + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))" + ) + ) + # pg_trgm indexes for efficient ILIKE '%term%' searches on titles + # Critical for document mention picker (@mentions) to scale + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)" + ) + ) + # B-tree index on search_space_id for fast filtering + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)" + ) + ) + # Covering index for "recent documents" query - enables index-only scan + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)" + ) ) - await conn.execute(text(f'DROP INDEX CONCURRENTLY IF EXISTS "{name}"')) - - -async def setup_indexes() -> None: - """Ensure search/vector indexes exist without ever blocking startup. 
- - Each index is created with ``CONCURRENTLY`` (so it never takes a blocking - SHARE lock on documents/chunks) under a short per-session ``lock_timeout`` - (so a contended boot fails fast instead of hanging the lifespan forever). - Failures are logged and swallowed per-index — a missing index just gets - retried on the next boot rather than crash-looping the API. - """ - lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS) - # AUTOCOMMIT is mandatory: CREATE INDEX CONCURRENTLY cannot run inside a - # transaction block. - async with engine.connect() as base_conn: - conn = await base_conn.execution_options(isolation_level="AUTOCOMMIT") - await conn.execute(text(f"SET lock_timeout = {lock_timeout_ms}")) - for name, table, ddl in _INDEX_DEFINITIONS: - try: - await _drop_invalid_index(conn, name) - await conn.execute(text(ddl)) - except Exception as exc: - # Non-fatal by design: a missing index is retried next boot. - logger.warning( - "[startup] index %s on %s not ready (%s: %s); " - "will retry on next boot", - name, - table, - exc.__class__.__name__, - exc, - ) async def create_db_and_tables(): - if not config.DB_BOOTSTRAP_ON_STARTUP: - logger.info( - "[startup] DB bootstrap skipped (DB_BOOTSTRAP_ON_STARTUP=FALSE); " - "schema/indexes are expected to be managed by migrations" - ) - return - - lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS) async with engine.begin() as conn: - # Fail fast instead of hanging forever if another session holds a - # conflicting lock on a table we need to touch. 
- await conn.execute(text(f"SET LOCAL lock_timeout = {lock_timeout_ms}")) await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) await conn.run_sync(Base.metadata.create_all) diff --git a/surfsense_backend/app/etl_pipeline/cache/__init__.py b/surfsense_backend/app/etl_pipeline/cache/__init__.py deleted file mode 100644 index 3f4585778..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Content-addressed reuse of expensive ETL parser output across workspaces.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.cached_extraction import extract_with_cache -from app.etl_pipeline.cache.service import EtlCacheService - -__all__ = [ - "EtlCacheService", - "extract_with_cache", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py deleted file mode 100644 index de4186b69..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Entry point: serve ETL parses from cache, parsing only on a miss.""" - -from __future__ import annotations - -import asyncio -import hashlib -import logging - -from app.config import config -from app.etl_pipeline.cache.eligibility import is_parse_cacheable -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.cache.settings import load_etl_cache_settings -from app.etl_pipeline.etl_document import EtlRequest, EtlResult -from app.etl_pipeline.etl_pipeline_service import EtlPipelineService -from app.observability import metrics - -logger = logging.getLogger(__name__) - -_HASH_CHUNK = 1024 * 1024 - - -async def extract_with_cache(request: EtlRequest, *, vision_llm=None) -> EtlResult: - """Drop-in for ``EtlPipelineService.extract`` that reuses prior parser output.""" - settings = 
load_etl_cache_settings() - - cacheable = is_parse_cacheable( - filename=request.filename, - etl_service=config.ETL_SERVICE, - cache_enabled=settings.enabled, - has_vision_llm=vision_llm is not None, - ) - if not cacheable: - return await EtlPipelineService(vision_llm=vision_llm).extract(request) - - key = ParseKey.for_document( - await asyncio.to_thread(_hash_file, request.file_path), - etl_service=config.ETL_SERVICE, - mode=request.processing_mode.value, - version=settings.parser_version, - ) - - cached_result = await _recall(key) - if cached_result is not None: - metrics.record_etl_cache_lookup( - etl_service=key.etl_service, mode=key.mode, outcome="hit" - ) - logger.debug("ETL cache hit for %s", key.source_sha256) - return cached_result - - metrics.record_etl_cache_lookup( - etl_service=key.etl_service, mode=key.mode, outcome="miss" - ) - result = await EtlPipelineService(vision_llm=vision_llm).extract(request) - await _remember(key, result) - return result - - -async def _recall(key: ParseKey) -> EtlResult | None: - # Caching is best-effort: any failure falls through to a normal parse. 
- try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - return await EtlCacheService(session).recall(key) - except Exception: - logger.warning("ETL cache recall failed; parsing fresh", exc_info=True) - return None - - -async def _remember(key: ParseKey, result: EtlResult) -> None: - try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - await EtlCacheService(session).remember(key, result) - except Exception: - logger.warning("ETL cache write failed; result not cached", exc_info=True) - - -def _hash_file(path: str) -> str: - digest = hashlib.sha256() - with open(path, "rb") as handle: - for chunk in iter(lambda: handle.read(_HASH_CHUNK), b""): - digest.update(chunk) - return digest.hexdigest() diff --git a/surfsense_backend/app/etl_pipeline/cache/eligibility.py b/surfsense_backend/app/etl_pipeline/cache/eligibility.py deleted file mode 100644 index 18f096218..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eligibility.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Gating rule: may this upload be served from / written to the parse cache?""" - -from __future__ import annotations - -from app.etl_pipeline.file_classifier import FileCategory, classify_file - - -def is_parse_cacheable( - *, - filename: str, - etl_service: str | None, - cache_enabled: bool, - has_vision_llm: bool, -) -> bool: - """Only deterministic document parses are shareable across workspaces. - - Vision-LLM runs append model-generated content not captured by the cache key, - and a missing ETL service means there is no document parser to key against -- - both bypass the cache. Non-document categories (plaintext, audio, images, - direct-convert) are cheap or parser-agnostic and are handled outside it. 
- """ - if not cache_enabled: - return False - if has_vision_llm: - return False - if not etl_service: - return False - return classify_file(filename) == FileCategory.DOCUMENT diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py deleted file mode 100644 index f47b9c4e0..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Background pruning of the parse cache by age and size budget.""" - -from __future__ import annotations - -from .task import evict_etl_cache_task - -__all__ = [ - "evict_etl_cache_task", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py b/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py deleted file mode 100644 index 5a80752d6..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Pure selection rules for which cached entries to drop.""" - -from __future__ import annotations - -from collections.abc import Iterable - -from app.etl_pipeline.cache.schemas import EvictionCandidate - - -def select_over_budget( - coldest_first: Iterable[EvictionCandidate], - *, - current_total_bytes: int, - max_total_bytes: int, -) -> list[EvictionCandidate]: - """Pick coldest entries until the footprint drops under the budget.""" - bytes_to_free = current_total_bytes - max_total_bytes - if bytes_to_free <= 0: - return [] - - chosen: list[EvictionCandidate] = [] - bytes_freed = 0 - for candidate in coldest_first: - if bytes_freed >= bytes_to_free: - break - chosen.append(candidate) - bytes_freed += candidate.size_bytes - return chosen diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py deleted file mode 100644 index 61433f8a7..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Celery task that prunes 
the parse cache by TTL, then by size budget.""" - -from __future__ import annotations - -import contextlib -import logging -from datetime import UTC, datetime, timedelta - -from app.celery_app import celery_app -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.etl_pipeline.cache.settings import load_etl_cache_settings -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.observability import metrics -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="evict_etl_cache") -def evict_etl_cache_task(): - return run_async_celery_task(_evict) - - -async def _evict() -> None: - """Expire stale entries, then shed the coldest overflow only if still over budget.""" - settings = load_etl_cache_settings() - if not settings.enabled: - return - - store = MarkdownCacheStore() - async with get_celery_session_maker()() as session: - index = CachedParseRepository(session) - - cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) - expired = await index.select_expired( - cutoff=cutoff, limit=settings.eviction_batch - ) - await _drop(index, store, expired, phase="ttl") - - total = await index.total_size_bytes() - if total > settings.max_total_bytes: - coldest = await index.select_coldest(limit=settings.eviction_batch) - over_budget = select_over_budget( - coldest, - current_total_bytes=total, - max_total_bytes=settings.max_total_bytes, - ) - await _drop(index, store, over_budget, phase="size") - - -async def _drop( - index: CachedParseRepository, - store: MarkdownCacheStore, - candidates: list[EvictionCandidate], - *, - phase: str, -) -> None: - if not candidates: - return - for candidate in candidates: - # Drop the index row even if the blob delete fails (orphan blob is harmless). 
- with contextlib.suppress(Exception): - await store.delete(candidate.storage_key) - await index.delete_by_ids([candidate.id for candidate in candidates]) - metrics.record_etl_cache_eviction(len(candidates), phase=phase) - logger.info("Evicted %d cached parses (%s)", len(candidates), phase) diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py deleted file mode 100644 index 666e4cfa8..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Database access for cached parse rows.""" - -from __future__ import annotations - -from .models import CachedParse -from .repository import CachedParseRepository - -__all__ = [ - "CachedParse", - "CachedParseRepository", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/models.py b/surfsense_backend/app/etl_pipeline/cache/persistence/models.py deleted file mode 100644 index bd20bdd12..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/models.py +++ /dev/null @@ -1,49 +0,0 @@ -"""``etl_cache_parses``: one reusable parser result per (bytes + recipe).""" - -from __future__ import annotations - -from sqlalchemy import ( - BigInteger, - Column, - DateTime, - Index, - Integer, - String, - UniqueConstraint, -) - -from app.db import BaseModel, TimestampMixin - - -class CachedParse(BaseModel, TimestampMixin): - __tablename__ = "etl_cache_parses" - - # Key: raw bytes + the recipe that produced the markdown. - source_sha256 = Column(String(64), nullable=False) - etl_service = Column(String(32), nullable=False) - mode = Column(String(16), nullable=False) - parser_version = Column(Integer, nullable=False) - - # Where the markdown blob lives (kept out of the row to stay small). 
- storage_backend = Column(String(32), nullable=False) - storage_key = Column(String, nullable=False) - size_bytes = Column(BigInteger, nullable=False) - - # Payload needed to rebuild the EtlResult on a hit. - content_type = Column(String(32), nullable=False) - actual_pages = Column(Integer, nullable=False, default=0, server_default="0") - - # Drives eviction (popularity + recency). - times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") - last_used_at = Column(DateTime(timezone=True), nullable=False) - - __table_args__ = ( - UniqueConstraint( - "source_sha256", - "etl_service", - "mode", - "parser_version", - name="uq_etl_cache_parses_key", - ), - Index("ix_etl_cache_parses_last_used_at", "last_used_at"), - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py b/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py deleted file mode 100644 index 05f40eae5..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py +++ /dev/null @@ -1,121 +0,0 @@ -"""CRUD and eviction selectors for ``etl_cache_parses`` (no business rules).""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import delete, func, select, update -from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.schemas import EvictionCandidate, ParseKey - -from .models import CachedParse - -_EVICTION_COLUMNS = ( - CachedParse.id, - CachedParse.storage_key, - CachedParse.size_bytes, - CachedParse.last_used_at, - CachedParse.times_reused, -) - - -def _as_eviction_candidate(row) -> EvictionCandidate: - return EvictionCandidate( - id=row.id, - storage_key=row.storage_key, - size_bytes=row.size_bytes, - last_used_at=row.last_used_at, - times_reused=row.times_reused, - ) - - -class CachedParseRepository: - def __init__(self, session: AsyncSession) -> None: - self._session = session - - async def 
get(self, key: ParseKey) -> CachedParse | None: - result = await self._session.execute( - select(CachedParse).where( - CachedParse.source_sha256 == key.source_sha256, - CachedParse.etl_service == key.etl_service, - CachedParse.mode == key.mode, - CachedParse.parser_version == key.version, - ) - ) - return result.scalars().first() - - async def insert( - self, - *, - key: ParseKey, - content_type: str, - actual_pages: int, - storage_backend: str, - storage_key: str, - size_bytes: int, - ) -> None: - # Concurrent writers parse identical bytes, so a lost race is harmless. - now = datetime.now(UTC) - await self._session.execute( - pg_insert(CachedParse) - .values( - source_sha256=key.source_sha256, - etl_service=key.etl_service, - mode=key.mode, - parser_version=key.version, - content_type=content_type, - actual_pages=actual_pages, - storage_backend=storage_backend, - storage_key=storage_key, - size_bytes=size_bytes, - times_reused=0, - last_used_at=now, - created_at=now, - ) - .on_conflict_do_nothing(constraint="uq_etl_cache_parses_key") - ) - await self._session.commit() - - async def mark_used(self, row_id: int) -> None: - await self._session.execute( - update(CachedParse) - .where(CachedParse.id == row_id) - .values( - times_reused=CachedParse.times_reused + 1, - last_used_at=datetime.now(UTC), - ) - ) - await self._session.commit() - - async def total_size_bytes(self) -> int: - result = await self._session.execute( - select(func.coalesce(func.sum(CachedParse.size_bytes), 0)) - ) - return int(result.scalar() or 0) - - async def select_expired( - self, *, cutoff: datetime, limit: int - ) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .where(CachedParse.last_used_at < cutoff) - .order_by(CachedParse.last_used_at.asc()) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: - result = await self._session.execute( - 
select(*_EVICTION_COLUMNS) - .order_by(CachedParse.times_reused.asc(), CachedParse.last_used_at.asc()) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def delete_by_ids(self, ids: list[int]) -> None: - if not ids: - return - await self._session.execute(delete(CachedParse).where(CachedParse.id.in_(ids))) - await self._session.commit() diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py deleted file mode 100644 index c88ac0c72..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pure value objects for the parse cache.""" - -from __future__ import annotations - -from .eviction_candidate import EvictionCandidate -from .parse_key import ParseKey - -__all__ = [ - "EvictionCandidate", - "ParseKey", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py b/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py deleted file mode 100644 index 13a903e7d..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Row projection handed to the eviction policy.""" - -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime - - -@dataclass(frozen=True, slots=True) -class EvictionCandidate: - id: int - storage_key: str - size_bytes: int - last_used_at: datetime - times_reused: int diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py deleted file mode 100644 index 88133a418..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Identity of a cacheable parse: equal keys yield identical markdown.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True, 
slots=True) -class ParseKey: - source_sha256: str - etl_service: str - mode: str - version: int - - @classmethod - def for_document( - cls, source_sha256: str, *, etl_service: str, mode: str, version: int - ) -> ParseKey: - return cls( - source_sha256=source_sha256, - etl_service=etl_service, - mode=mode, - version=version, - ) - - @property - def object_suffix(self) -> str: - return f"{self.etl_service}.{self.mode}.v{self.version}.md" diff --git a/surfsense_backend/app/etl_pipeline/cache/service.py b/surfsense_backend/app/etl_pipeline/cache/service.py deleted file mode 100644 index 49398faf8..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/service.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Recall and remember parser output, coordinating the index and blob store.""" - -from __future__ import annotations - -import logging - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.etl_pipeline.etl_document import EtlResult - -logger = logging.getLogger(__name__) - - -class EtlCacheService: - def __init__(self, session: AsyncSession) -> None: - self._index = CachedParseRepository(session) - self._store = MarkdownCacheStore() - - async def recall(self, key: ParseKey) -> EtlResult | None: - """Return the cached result, or None on a miss.""" - row = await self._index.get(key) - if row is None: - return None - - try: - markdown = await self._store.load(row.storage_key) - except Exception: - # Index points at a blob that is gone; treat as a miss and re-parse. 
- logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) - return None - - await self._index.mark_used(row.id) - return EtlResult( - markdown_content=markdown, - etl_service=row.etl_service, - actual_pages=row.actual_pages, - content_type=row.content_type, - ) - - async def remember(self, key: ParseKey, result: EtlResult) -> None: - """Store a freshly parsed result for future reuse.""" - storage_key = await self._store.save(key, result.markdown_content) - await self._index.insert( - key=key, - content_type=result.content_type, - actual_pages=result.actual_pages, - storage_backend=self._store.backend_name, - storage_key=storage_key, - size_bytes=len(result.markdown_content.encode("utf-8")), - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/settings.py b/surfsense_backend/app/etl_pipeline/cache/settings.py deleted file mode 100644 index 5911ea222..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/settings.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Cache configuration resolved from the central ``Config``.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class EtlCacheSettings: - enabled: bool - parser_version: int - ttl_days: int - max_total_bytes: int - eviction_batch: int - # None for any storage_* field means: reuse the main file_storage backend. 
- storage_backend: str | None - storage_container: str | None - storage_local_root: str | None - - -def load_etl_cache_settings() -> EtlCacheSettings: - from app.config import config - - return EtlCacheSettings( - enabled=config.ETL_CACHE_ENABLED, - parser_version=config.ETL_CACHE_PARSER_VERSION, - ttl_days=config.ETL_CACHE_TTL_DAYS, - max_total_bytes=config.ETL_CACHE_MAX_TOTAL_MB * 1024 * 1024, - eviction_batch=config.ETL_CACHE_EVICTION_BATCH, - storage_backend=config.ETL_CACHE_STORAGE_BACKEND or None, - storage_container=config.ETL_CACHE_STORAGE_CONTAINER or None, - storage_local_root=config.ETL_CACHE_STORAGE_LOCAL_PATH or None, - ) diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py b/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py deleted file mode 100644 index bed39c510..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Blob storage for cached parse markdown.""" - -from __future__ import annotations - -from .markdown_store import MarkdownCacheStore - -__all__ = [ - "MarkdownCacheStore", -] diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/backend.py b/surfsense_backend/app/etl_pipeline/cache/storage/backend.py deleted file mode 100644 index 4f68ac0d3..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/backend.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Resolve the storage backend for cache blobs: shared main store or a dedicated one.""" - -from __future__ import annotations - -from functools import lru_cache - -from app.file_storage.backends.base import StorageBackend - - -@lru_cache(maxsize=1) -def resolve_cache_backend() -> StorageBackend: - from app.etl_pipeline.cache.settings import load_etl_cache_settings - - settings = load_etl_cache_settings() - - if not settings.storage_backend: - from app.file_storage.factory import get_storage_backend - - return get_storage_backend() - - backend = settings.storage_backend.strip().lower() - - if backend 
== "azure": - from app.config import config - - if not settings.storage_container: - raise ValueError("ETL_CACHE_STORAGE_CONTAINER is required for azure cache.") - if not config.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError( - "AZURE_STORAGE_CONNECTION_STRING is required for azure cache." - ) - from app.file_storage.backends.azure import AzureBlobBackend - - return AzureBlobBackend( - connection_string=config.AZURE_STORAGE_CONNECTION_STRING, - container=settings.storage_container, - ) - - if backend == "local": - if not settings.storage_local_root: - raise ValueError( - "ETL_CACHE_STORAGE_LOCAL_PATH is required for local cache." - ) - from app.file_storage.backends.local import LocalFileBackend - - return LocalFileBackend(settings.storage_local_root) - - raise ValueError(f"Unknown ETL_CACHE_STORAGE_BACKEND: {settings.storage_backend!r}") diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py b/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py deleted file mode 100644 index 189f3508b..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Read and write cached markdown blobs through the resolved backend.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage.backend import resolve_cache_backend -from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key - -_MARKDOWN_CONTENT_TYPE = "text/markdown; charset=utf-8" - - -class MarkdownCacheStore: - def __init__(self) -> None: - self._backend = resolve_cache_backend() - - @property - def backend_name(self) -> str: - return self._backend.backend_name - - async def save(self, key: ParseKey, markdown: str) -> str: - """Persist the markdown and return its storage key for the index row.""" - storage_key = build_parse_object_key(key) - await self._backend.put( - storage_key, - markdown.encode("utf-8"), - 
content_type=_MARKDOWN_CONTENT_TYPE, - ) - return storage_key - - async def load(self, storage_key: str) -> str: - chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] - return b"".join(chunks).decode("utf-8") - - async def delete(self, storage_key: str) -> None: - await self._backend.delete(storage_key) diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py deleted file mode 100644 index 7b89c3f92..000000000 --- a/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Object keys for cached markdown, namespaced under a dedicated prefix.""" - -from __future__ import annotations - -from app.etl_pipeline.cache.schemas import ParseKey - -CACHE_PREFIX = "etl_cache" - - -def build_parse_object_key(key: ParseKey) -> str: - # Content-addressed: identical bytes + recipe always map to the same key. - return f"{CACHE_PREFIX}/{key.source_sha256}/{key.object_suffix}" diff --git a/surfsense_backend/app/gateway/__init__.py b/surfsense_backend/app/gateway/__init__.py index 89b931bc3..8b79b3160 100644 --- a/surfsense_backend/app/gateway/__init__.py +++ b/surfsense_backend/app/gateway/__init__.py @@ -8,7 +8,7 @@ from app.config import config def require_gateway_enabled() -> None: - """FastAPI dependency that gates gateway operational routes on the global flag. + """FastAPI dependency that gates all gateway HTTP routes on the global flag. 
Returns 404 (rather than 503) when ``GATEWAY_ENABLED`` is FALSE so that disabling the gateway makes its webhook/OAuth/pairing surface indistinguishable diff --git a/surfsense_backend/app/indexing_pipeline/cache/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/__init__.py deleted file mode 100644 index d3b9e5f0d..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Content-addressed reuse of chunk+embedding output across workspaces.""" - -from __future__ import annotations - -from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings -from app.indexing_pipeline.cache.service import EmbeddingCacheService - -__all__ = [ - "EmbeddingCacheService", - "build_chunk_embeddings", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py deleted file mode 100644 index 95321a229..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Entry point: serve chunk embeddings from cache, embedding only on a miss. - -Embeddings are a pure function of the markdown, the embedding model, and the -chunker -- so identical markdown is chunked and embedded once and reused across -workspaces, even when it came from different sources. 
-""" - -from __future__ import annotations - -import asyncio -import hashlib -import logging - -import numpy as np - -from app.config import config -from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.service import EmbeddingCacheService -from app.indexing_pipeline.cache.settings import load_embedding_cache_settings -from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid -from app.indexing_pipeline.document_embedder import embed_texts -from app.observability import metrics - -logger = logging.getLogger(__name__) - -ChunkPair = tuple[str, np.ndarray] - - -async def build_chunk_embeddings( - markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - """Return the document-level vector and ordered ``(chunk_text, vector)`` pairs. - - Drop-in for the inline chunk+embed step; reuses prior output when the same - markdown has already been embedded with the current model and chunker. 
- """ - settings = load_embedding_cache_settings() - chunker_kind = "code" if use_code_chunker else "hybrid" - embedding_dim = getattr(config.embedding_model_instance, "dimension", None) - - cacheable = is_embedding_cacheable( - cache_enabled=settings.enabled, - embedding_model=config.EMBEDDING_MODEL, - embedding_dim=embedding_dim, - ) - if not cacheable: - return await _compute(markdown, use_code_chunker=use_code_chunker) - - key = EmbeddingKey( - markdown_sha256=_hash_text(markdown), - embedding_model=config.EMBEDDING_MODEL, - embedding_dim=int(embedding_dim), - chunker_kind=chunker_kind, - chunker_version=settings.chunker_version, - ) - - cached = await _recall(key) - if cached is not None: - metrics.record_embedding_cache_lookup( - embedding_model=key.embedding_model, - chunker_kind=chunker_kind, - outcome="hit", - ) - logger.debug("Embedding cache hit for %s", key.markdown_sha256) - return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] - - metrics.record_embedding_cache_lookup( - embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss" - ) - summary_embedding, chunk_pairs = await _compute( - markdown, use_code_chunker=use_code_chunker - ) - await _remember(key, summary_embedding, chunk_pairs) - return summary_embedding, chunk_pairs - - -async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]: - """Chunk markdown into ordered texts with the pipeline's chunker selection.""" - if use_code_chunker: - return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True) - # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). 
- return await asyncio.to_thread(chunk_text_hybrid, markdown) - - -async def embed_batch(texts: list[str]) -> list[np.ndarray]: - """Embed texts in one batch off the event loop.""" - return await asyncio.to_thread(embed_texts, texts) - - -async def _compute( - markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker) - embeddings = await embed_batch([markdown, *chunk_texts]) - summary_embedding, *chunk_embeddings = embeddings - return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False)) - - -async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: - # Caching is best-effort: any failure falls through to a normal embed. - try: - from app.tasks.celery_tasks import get_celery_session_maker - - async with get_celery_session_maker()() as session: - return await EmbeddingCacheService(session).recall(key) - except Exception: - logger.warning("Embedding cache recall failed; embedding fresh", exc_info=True) - return None - - -async def _remember( - key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair] -) -> None: - try: - from app.tasks.celery_tasks import get_celery_session_maker - - embedding_set = EmbeddingSet( - summary_embedding=summary_embedding, - chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs], - ) - async with get_celery_session_maker()() as session: - await EmbeddingCacheService(session).remember(key, embedding_set) - except Exception: - logger.warning("Embedding cache write failed; result not cached", exc_info=True) - - -def _hash_text(text: str) -> str: - return hashlib.sha256(text.encode("utf-8")).hexdigest() diff --git a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py deleted file mode 100644 index 446bea2f8..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py +++ /dev/null @@ 
-1,21 +0,0 @@ -"""Gating rule: may this document be served from / written to the embedding cache?""" - -from __future__ import annotations - - -def is_embedding_cacheable( - *, - cache_enabled: bool, - embedding_model: str | None, - embedding_dim: int | None, -) -> bool: - """Cache only when a concrete embedding model and dimension are configured. - - Without a model there is nothing to key against, and without a dimension the - blob's integrity guard cannot run -- both bypass the cache. - """ - if not cache_enabled: - return False - if not embedding_model: - return False - return bool(embedding_dim) diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py deleted file mode 100644 index a0f74b360..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Background pruning of the embedding cache by age and size budget.""" - -from __future__ import annotations - -from .task import evict_embedding_cache_task - -__all__ = [ - "evict_embedding_cache_task", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py deleted file mode 100644 index 70eff6ea5..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Celery task that prunes the embedding cache by TTL, then by size budget.""" - -from __future__ import annotations - -import contextlib -import logging -from datetime import UTC, datetime, timedelta - -from app.celery_app import celery_app -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.settings import load_embedding_cache_settings -from app.indexing_pipeline.cache.storage 
import EmbeddingCacheStore -from app.observability import metrics -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="evict_embedding_cache") -def evict_embedding_cache_task(): - return run_async_celery_task(_evict) - - -async def _evict() -> None: - """Expire stale entries, then shed the coldest overflow only if still over budget.""" - settings = load_embedding_cache_settings() - if not settings.enabled: - return - - store = EmbeddingCacheStore() - async with get_celery_session_maker()() as session: - index = CachedEmbeddingSetRepository(session) - - cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) - expired = await index.select_expired( - cutoff=cutoff, limit=settings.eviction_batch - ) - await _drop(index, store, expired, phase="ttl") - - total = await index.total_size_bytes() - if total > settings.max_total_bytes: - coldest = await index.select_coldest(limit=settings.eviction_batch) - over_budget = select_over_budget( - coldest, - current_total_bytes=total, - max_total_bytes=settings.max_total_bytes, - ) - await _drop(index, store, over_budget, phase="size") - - -async def _drop( - index: CachedEmbeddingSetRepository, - store: EmbeddingCacheStore, - candidates: list[EvictionCandidate], - *, - phase: str, -) -> None: - if not candidates: - return - for candidate in candidates: - # Drop the index row even if the blob delete fails (orphan blob is harmless). 
- with contextlib.suppress(Exception): - await store.delete(candidate.storage_key) - await index.delete_by_ids([candidate.id for candidate in candidates]) - metrics.record_embedding_cache_eviction(len(candidates), phase=phase) - logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py deleted file mode 100644 index 62cde0d05..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Database access for cached embedding sets.""" - -from __future__ import annotations - -from .models import CachedEmbeddingSet -from .repository import CachedEmbeddingSetRepository - -__all__ = [ - "CachedEmbeddingSet", - "CachedEmbeddingSetRepository", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py deleted file mode 100644 index af34d92d2..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py +++ /dev/null @@ -1,47 +0,0 @@ -"""``embedding_cache_sets``: one reusable chunk+embedding set per markdown.""" - -from __future__ import annotations - -from sqlalchemy import ( - BigInteger, - Column, - DateTime, - Index, - Integer, - String, - UniqueConstraint, -) - -from app.db import BaseModel, TimestampMixin - - -class CachedEmbeddingSet(BaseModel, TimestampMixin): - __tablename__ = "embedding_cache_sets" - - # Key: markdown text + the recipe that turned it into vectors. - markdown_sha256 = Column(String(64), nullable=False) - embedding_model = Column(String(255), nullable=False) - embedding_dim = Column(Integer, nullable=False) - chunker_kind = Column(String(8), nullable=False) - chunker_version = Column(Integer, nullable=False) - - # Where the embedding blob lives (kept out of the row to stay small). 
- storage_backend = Column(String(32), nullable=False) - storage_key = Column(String, nullable=False) - size_bytes = Column(BigInteger, nullable=False) - chunk_count = Column(Integer, nullable=False, default=0, server_default="0") - - # Drives eviction (popularity + recency). - times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") - last_used_at = Column(DateTime(timezone=True), nullable=False) - - __table_args__ = ( - UniqueConstraint( - "markdown_sha256", - "embedding_model", - "chunker_kind", - "chunker_version", - name="uq_embedding_cache_sets_key", - ), - Index("ix_embedding_cache_sets_last_used_at", "last_used_at"), - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py deleted file mode 100644 index f7f1f4345..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py +++ /dev/null @@ -1,126 +0,0 @@ -"""CRUD and eviction selectors for ``embedding_cache_sets`` (no business rules).""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import delete, func, select, update -from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.ext.asyncio import AsyncSession - -from app.etl_pipeline.cache.schemas import EvictionCandidate -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -from .models import CachedEmbeddingSet - -_EVICTION_COLUMNS = ( - CachedEmbeddingSet.id, - CachedEmbeddingSet.storage_key, - CachedEmbeddingSet.size_bytes, - CachedEmbeddingSet.last_used_at, - CachedEmbeddingSet.times_reused, -) - - -def _as_eviction_candidate(row) -> EvictionCandidate: - return EvictionCandidate( - id=row.id, - storage_key=row.storage_key, - size_bytes=row.size_bytes, - last_used_at=row.last_used_at, - times_reused=row.times_reused, - ) - - -class CachedEmbeddingSetRepository: - def __init__(self, session: AsyncSession) -> None: - 
self._session = session - - async def get(self, key: EmbeddingKey) -> CachedEmbeddingSet | None: - result = await self._session.execute( - select(CachedEmbeddingSet).where( - CachedEmbeddingSet.markdown_sha256 == key.markdown_sha256, - CachedEmbeddingSet.embedding_model == key.embedding_model, - CachedEmbeddingSet.chunker_kind == key.chunker_kind, - CachedEmbeddingSet.chunker_version == key.chunker_version, - ) - ) - return result.scalars().first() - - async def insert( - self, - *, - key: EmbeddingKey, - storage_backend: str, - storage_key: str, - size_bytes: int, - chunk_count: int, - ) -> None: - # Concurrent writers embed identical markdown, so a lost race is harmless. - now = datetime.now(UTC) - await self._session.execute( - pg_insert(CachedEmbeddingSet) - .values( - markdown_sha256=key.markdown_sha256, - embedding_model=key.embedding_model, - embedding_dim=key.embedding_dim, - chunker_kind=key.chunker_kind, - chunker_version=key.chunker_version, - storage_backend=storage_backend, - storage_key=storage_key, - size_bytes=size_bytes, - chunk_count=chunk_count, - times_reused=0, - last_used_at=now, - created_at=now, - ) - .on_conflict_do_nothing(constraint="uq_embedding_cache_sets_key") - ) - await self._session.commit() - - async def mark_used(self, row_id: int) -> None: - await self._session.execute( - update(CachedEmbeddingSet) - .where(CachedEmbeddingSet.id == row_id) - .values( - times_reused=CachedEmbeddingSet.times_reused + 1, - last_used_at=datetime.now(UTC), - ) - ) - await self._session.commit() - - async def total_size_bytes(self) -> int: - result = await self._session.execute( - select(func.coalesce(func.sum(CachedEmbeddingSet.size_bytes), 0)) - ) - return int(result.scalar() or 0) - - async def select_expired( - self, *, cutoff: datetime, limit: int - ) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .where(CachedEmbeddingSet.last_used_at < cutoff) - .order_by(CachedEmbeddingSet.last_used_at.asc()) 
- .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: - result = await self._session.execute( - select(*_EVICTION_COLUMNS) - .order_by( - CachedEmbeddingSet.times_reused.asc(), - CachedEmbeddingSet.last_used_at.asc(), - ) - .limit(limit) - ) - return [_as_eviction_candidate(row) for row in result] - - async def delete_by_ids(self, ids: list[int]) -> None: - if not ids: - return - await self._session.execute( - delete(CachedEmbeddingSet).where(CachedEmbeddingSet.id.in_(ids)) - ) - await self._session.commit() diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py deleted file mode 100644 index c200ca1a6..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Pure value objects for the embedding cache.""" - -from __future__ import annotations - -from .embedding_key import EmbeddingKey -from .embedding_set import CachedChunk, EmbeddingSet - -__all__ = [ - "CachedChunk", - "EmbeddingKey", - "EmbeddingSet", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py deleted file mode 100644 index 55d891e73..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Identity of a cacheable embedding set: equal keys yield identical vectors. - -Embeddings depend on the markdown text, the embedding model, and the chunker -- -never on how the markdown was produced. So the key is the markdown's own hash -plus the model and chunker recipe, not the upstream parse identity. 
-""" - -from __future__ import annotations - -import hashlib -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class EmbeddingKey: - markdown_sha256: str - embedding_model: str - embedding_dim: int - chunker_kind: str - chunker_version: int - - @property - def object_suffix(self) -> str: - # Fingerprint the model so distinct models never share a blob, while the - # markdown hash (the object's folder) stays human-readable. - fingerprint = hashlib.sha256(self.embedding_model.encode("utf-8")).hexdigest() - return f"{fingerprint[:16]}.{self.chunker_kind}.v{self.chunker_version}.emb" diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py deleted file mode 100644 index 68c3a5211..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The cached payload: a document's chunk texts paired with their vectors.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np - - -@dataclass(frozen=True, slots=True) -class CachedChunk: - text: str - embedding: np.ndarray - - -@dataclass(frozen=True, slots=True) -class EmbeddingSet: - """Everything the indexer needs to rebuild a document's chunks without embedding. - - ``summary_embedding`` is the document-level vector; ``chunks`` are the ordered - chunk texts and their vectors. - """ - - summary_embedding: np.ndarray - chunks: list[CachedChunk] - - @property - def chunk_count(self) -> int: - return len(self.chunks) diff --git a/surfsense_backend/app/indexing_pipeline/cache/serialization.py b/surfsense_backend/app/indexing_pipeline/cache/serialization.py deleted file mode 100644 index fde0acd00..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/serialization.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle). 
- -Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header -carries the dim, chunk count, and ordered chunk texts; the matrix holds the -summary vector followed by one row per chunk, all float32 for compactness. -""" - -from __future__ import annotations - -import json -import struct - -import numpy as np - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet - -# Marker at the start of every blob: "SurfSense EMBeddings, version 1"-> SSEMB1. Lets us -# reject foreign blobs and bump the trailing digit if the layout ever changes. -_MAGIC = b"SSEMB1" -# 4-byte big-endian unsigned int written before the variable-length JSON header, -# so the reader knows where the header ends and the float matrix begins. -_HEADER_LEN = struct.Struct(">I") - - -def serialize(embedding_set: EmbeddingSet) -> bytes: - summary = np.asarray(embedding_set.summary_embedding, dtype=np.float32).reshape(-1) - dim = int(summary.shape[0]) - - rows = [summary] - texts: list[str] = [] - for chunk in embedding_set.chunks: - vector = np.asarray(chunk.embedding, dtype=np.float32).reshape(-1) - if vector.shape[0] != dim: - raise ValueError( - "All vectors in an embedding set must share one dimension." 
- ) - rows.append(vector) - texts.append(chunk.text) - - matrix = np.stack(rows, axis=0) - header = json.dumps( - {"dim": dim, "count": len(texts), "texts": texts}, ensure_ascii=False - ).encode("utf-8") - return b"".join( - [_MAGIC, _HEADER_LEN.pack(len(header)), header, matrix.tobytes(order="C")] - ) - - -def deserialize(blob: bytes) -> EmbeddingSet: - view = memoryview(blob) - if bytes(view[: len(_MAGIC)]) != _MAGIC: - raise ValueError("Unrecognized embedding cache blob.") - - offset = len(_MAGIC) - (header_len,) = _HEADER_LEN.unpack(view[offset : offset + _HEADER_LEN.size]) - offset += _HEADER_LEN.size - - header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8")) - offset += header_len - - dim = int(header["dim"]) - count = int(header["count"]) - texts: list[str] = header["texts"] - - matrix = np.frombuffer(view[offset:], dtype=np.float32) - if matrix.shape[0] != (count + 1) * dim: - raise ValueError("Embedding cache blob is truncated or corrupt.") - matrix = matrix.reshape(count + 1, dim) - - return EmbeddingSet( - summary_embedding=matrix[0], - chunks=[ - CachedChunk(text=texts[i], embedding=matrix[i + 1]) for i in range(count) - ], - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/service.py b/surfsense_backend/app/indexing_pipeline/cache/service.py deleted file mode 100644 index b1d634782..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/service.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Recall and remember embedding sets, coordinating the index and blob store.""" - -from __future__ import annotations - -import logging - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.storage import EmbeddingCacheStore - -logger = logging.getLogger(__name__) - - -class EmbeddingCacheService: - def __init__(self, session: AsyncSession) -> 
None: - self._index = CachedEmbeddingSetRepository(session) - self._store = EmbeddingCacheStore() - - async def recall(self, key: EmbeddingKey) -> EmbeddingSet | None: - """Return the cached embedding set, or None on a miss.""" - row = await self._index.get(key) - if row is None: - return None - - try: - embedding_set = await self._store.load(row.storage_key) - except Exception: - # Index points at a blob that is gone; treat as a miss and re-embed. - logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) - return None - - if int(embedding_set.summary_embedding.shape[0]) != key.embedding_dim: - # A model swapped its dimension under a reused name; never serve it. - logger.warning("Cached embedding dimension mismatch: %s", row.storage_key) - return None - - await self._index.mark_used(row.id) - return embedding_set - - async def remember(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> None: - """Store a freshly embedded set for future reuse.""" - storage_key, size_bytes = await self._store.save(key, embedding_set) - await self._index.insert( - key=key, - storage_backend=self._store.backend_name, - storage_key=storage_key, - size_bytes=size_bytes, - chunk_count=embedding_set.chunk_count, - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/settings.py b/surfsense_backend/app/indexing_pipeline/cache/settings.py deleted file mode 100644 index 9c6737445..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/settings.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Embedding-cache configuration resolved from the central ``Config``. - -The blob backend is intentionally not configured here: it is shared with the ETL -parse cache (see ``ETL_CACHE_STORAGE_*``). 
-""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class EmbeddingCacheSettings: - enabled: bool - chunker_version: int - ttl_days: int - max_total_bytes: int - eviction_batch: int - - -def load_embedding_cache_settings() -> EmbeddingCacheSettings: - from app.config import config - - return EmbeddingCacheSettings( - enabled=config.EMBEDDING_CACHE_ENABLED, - chunker_version=config.EMBEDDING_CACHE_CHUNKER_VERSION, - ttl_days=config.EMBEDDING_CACHE_TTL_DAYS, - max_total_bytes=config.EMBEDDING_CACHE_MAX_TOTAL_MB * 1024 * 1024, - eviction_batch=config.EMBEDDING_CACHE_EVICTION_BATCH, - ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py deleted file mode 100644 index 72b04c34d..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Blob storage for cached embedding sets.""" - -from __future__ import annotations - -from .embedding_store import EmbeddingCacheStore - -__all__ = [ - "EmbeddingCacheStore", -] diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py deleted file mode 100644 index 7b0329b4e..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Read and write cached embedding blobs through the shared cache backend. - -The blob backend is shared with the ETL parse cache (same bucket / root), so -markdown and its embeddings live side by side; only the object prefix differs. 
-""" - -from __future__ import annotations - -from app.etl_pipeline.cache.storage.backend import resolve_cache_backend -from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.serialization import deserialize, serialize -from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key - -_EMBEDDING_CONTENT_TYPE = "application/octet-stream" - - -class EmbeddingCacheStore: - def __init__(self) -> None: - self._backend = resolve_cache_backend() - - @property - def backend_name(self) -> str: - return self._backend.backend_name - - async def save( - self, key: EmbeddingKey, embedding_set: EmbeddingSet - ) -> tuple[str, int]: - """Persist the embedding set and return its storage key and byte size.""" - blob = serialize(embedding_set) - storage_key = build_embedding_object_key(key) - await self._backend.put(storage_key, blob, content_type=_EMBEDDING_CONTENT_TYPE) - return storage_key, len(blob) - - async def load(self, storage_key: str) -> EmbeddingSet: - chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] - return deserialize(b"".join(chunks)) - - async def delete(self, storage_key: str) -> None: - await self._backend.delete(storage_key) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py deleted file mode 100644 index 6286ccf90..000000000 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Object keys for cached embedding sets, namespaced under a dedicated prefix.""" - -from __future__ import annotations - -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -CACHE_PREFIX = "embedding_cache" - - -def build_embedding_object_key(key: EmbeddingKey) -> str: - # Content-addressed: identical markdown + recipe always map to the same key. 
- return f"{CACHE_PREFIX}/{key.markdown_sha256}/{key.object_suffix}" diff --git a/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py b/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py deleted file mode 100644 index 9354aeb9f..000000000 --- a/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Diff a document's existing chunk rows against its freshly chunked texts. - -Embeddings are a pure function of chunk text, so a row whose content reappears -in the new chunking keeps its embedding (and its HNSW/GIN index entries); only -genuinely new texts are embedded and only vanished rows are deleted. Matching -is a greedy multiset match on content in document order, so duplicate -boilerplate chunks pair up one-to-one and reordered chunks become cheap -position updates instead of delete+reinsert. -""" - -from __future__ import annotations - -from collections import defaultdict, deque -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class ExistingChunk: - id: int - content: str - position: int - - -@dataclass(frozen=True, slots=True) -class ChunkPlan: - """The minimal set of writes that turns the stored chunks into the new ones. - - ``reused`` holds only kept rows whose position actually changed; rows that - match in place need no write at all. Kept-row count (for metrics) is - ``len(existing) - len(to_delete)``. 
- """ - - reused: list[tuple[int, int]] # (existing_chunk_id, new_position) - to_embed: list[tuple[int, str]] # (new_position, text) - to_delete: list[int] # existing chunk ids - - -def reconcile(existing: list[ExistingChunk], new_texts: list[str]) -> ChunkPlan: - available: dict[str, deque[ExistingChunk]] = defaultdict(deque) - for chunk in sorted(existing, key=lambda c: c.position): - available[chunk.content].append(chunk) - - reused: list[tuple[int, int]] = [] - to_embed: list[tuple[int, str]] = [] - - for new_position, text in enumerate(new_texts): - matches = available.get(text) - if matches: - chunk = matches.popleft() - if chunk.position != new_position: - reused.append((chunk.id, new_position)) - else: - to_embed.append((new_position, text)) - - to_delete = [chunk.id for queue in available.values() for chunk in queue] - return ChunkPlan(reused=reused, to_embed=to_embed, to_delete=to_delete) diff --git a/surfsense_backend/app/indexing_pipeline/document_persistence.py b/surfsense_backend/app/indexing_pipeline/document_persistence.py index b716560d2..9fd8867e2 100644 --- a/surfsense_backend/app/indexing_pipeline/document_persistence.py +++ b/surfsense_backend/app/indexing_pipeline/document_persistence.py @@ -1,12 +1,12 @@ import contextlib import logging -import time from datetime import UTC, datetime from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import set_committed_value -from app.db import Chunk, Document, DocumentStatus +from app.db import Document, DocumentStatus logger = logging.getLogger(__name__) @@ -22,6 +22,7 @@ async def rollback_and_persist_failure( try: await session.rollback() except Exception: + # Session is completely dead; surface it but never raise. 
logger.warning( "Rollback failed; cannot persist failed status for document %s", getattr(document, "id", "unknown"), @@ -34,6 +35,8 @@ async def rollback_and_persist_failure( document.status = DocumentStatus.failed(message) await session.commit() except Exception: + # Best-effort: the document stays non-ready and is retried next sync. + # Log it so a permanently-stuck document is at least traceable. logger.warning( "Could not persist failed status for document %s; will retry next sync", getattr(document, "id", "unknown"), @@ -43,60 +46,12 @@ async def rollback_and_persist_failure( await session.rollback() -async def persist_scratch_index( - session: AsyncSession, - document: Document, - content: str, - chunks: list[Chunk], - *, - batch_size: int, - perf: logging.Logger, -) -> None: - """Commit document content first, then chunk rows in batches, then mark ready.""" - if document.id is None: - raise ValueError("document.id is required to persist chunks") - - document.content = content - document.updated_at = datetime.now(UTC) - await session.commit() - - t_persist = time.perf_counter() - total = len(chunks) - if total == 0: - set_committed_value(document, "chunks", []) - document.status = DocumentStatus.ready() - document.updated_at = datetime.now(UTC) - await session.commit() - return - - effective_batch = total if batch_size <= 0 else batch_size - num_batches = (total + effective_batch - 1) // effective_batch - doc_id = document.id - - for batch_idx, start in enumerate(range(0, total, effective_batch), start=1): - batch = chunks[start : start + effective_batch] - t_batch = time.perf_counter() - for chunk in batch: - chunk.document_id = doc_id - session.add_all(batch) - await session.commit() - perf.info( - "[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs", - doc_id, - batch_idx, - num_batches, - len(batch), - time.perf_counter() - t_batch, - ) - +def attach_chunks_to_document(document: Document, chunks: list) -> None: + """Assign chunks to a document 
without triggering SQLAlchemy async lazy loading.""" set_committed_value(document, "chunks", chunks) - document.status = DocumentStatus.ready() - document.updated_at = datetime.now(UTC) - await session.commit() - perf.info( - "[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs", - doc_id, - total, - num_batches, - time.perf_counter() - t_persist, - ) + session = object_session(document) + if session is not None: + if document.id is not None: + for chunk in chunks: + chunk.document_id = document.id + session.add_all(chunks) diff --git a/surfsense_backend/app/indexing_pipeline/exceptions.py b/surfsense_backend/app/indexing_pipeline/exceptions.py index bf9d9e9fa..666fa4b9f 100644 --- a/surfsense_backend/app/indexing_pipeline/exceptions.py +++ b/surfsense_backend/app/indexing_pipeline/exceptions.py @@ -14,8 +14,6 @@ from litellm.exceptions import ( ) from sqlalchemy.exc import IntegrityError as IntegrityError -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception - # Tuples for use directly in except clauses. RETRYABLE_LLM_ERRORS = ( RateLimitError, @@ -99,20 +97,38 @@ def safe_exception_message(exc: Exception) -> str: def llm_retryable_message(exc: Exception) -> str: try: - adapted = adapt_llm_exception(exc) - if adapted.category is LLMErrorCategory.UNKNOWN: - return safe_exception_message(exc) - return adapted.user_message + if isinstance(exc, RateLimitError): + return PipelineMessages.RATE_LIMIT + if isinstance(exc, Timeout): + return PipelineMessages.LLM_TIMEOUT + if isinstance(exc, ServiceUnavailableError): + return PipelineMessages.LLM_UNAVAILABLE + if isinstance(exc, BadGatewayError): + return PipelineMessages.LLM_BAD_GATEWAY + if isinstance(exc, InternalServerError): + return PipelineMessages.LLM_SERVER_ERROR + if isinstance(exc, APIConnectionError): + return PipelineMessages.LLM_CONNECTION + return safe_exception_message(exc) except Exception: return "Something went wrong when calling the LLM." 
def llm_permanent_message(exc: Exception) -> str: try: - adapted = adapt_llm_exception(exc) - if adapted.category is LLMErrorCategory.UNKNOWN: - return safe_exception_message(exc) - return adapted.user_message + if isinstance(exc, AuthenticationError): + return PipelineMessages.LLM_AUTH + if isinstance(exc, PermissionDeniedError): + return PipelineMessages.LLM_PERMISSION + if isinstance(exc, NotFoundError): + return PipelineMessages.LLM_NOT_FOUND + if isinstance(exc, BadRequestError): + return PipelineMessages.LLM_BAD_REQUEST + if isinstance(exc, UnprocessableEntityError): + return PipelineMessages.LLM_UNPROCESSABLE + if isinstance(exc, APIResponseValidationError): + return PipelineMessages.LLM_RESPONSE + return safe_exception_message(exc) except Exception: return "Something went wrong when calling the LLM." diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 30ea9d5d6..67a6778e0 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import UTC, datetime -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -19,17 +19,16 @@ from app.db import ( DocumentStatus, DocumentType, ) -from app.indexing_pipeline.cache import build_chunk_embeddings -from app.indexing_pipeline.cache.cached_indexing import chunk_markdown, embed_batch -from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid +from app.indexing_pipeline.document_embedder import embed_texts from 
app.indexing_pipeline.document_hashing import ( compute_content_hash, compute_identifier_hash, compute_unique_identifier_hash, ) from app.indexing_pipeline.document_persistence import ( - persist_scratch_index, + attach_chunks_to_document, rollback_and_persist_failure, ) from app.indexing_pipeline.exceptions import ( @@ -381,50 +380,53 @@ class IndexingPipelineService: content = connector_doc.source_markdown - t_step = time.perf_counter() - existing = await self._load_existing_chunks(document.id) - if existing and self._reconcile_enabled(): - chunk_count = await self._reindex_incrementally( - document, content, connector_doc, existing - ) - perf.info( - "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, - chunk_count, - time.perf_counter() - t_step, - ) - document.content = content - document.updated_at = datetime.now(UTC) - document.status = DocumentStatus.ready() - await self.session.commit() - else: - from app.config import config + await self.session.execute( + delete(Chunk).where(Chunk.document_id == document.id) + ) - chunks = await self._reindex_from_scratch( - document, content, connector_doc + t_step = time.perf_counter() + if connector_doc.should_use_code_chunker: + chunk_texts = await asyncio.to_thread( + chunk_text, + connector_doc.source_markdown, + use_code_chunker=True, ) - chunk_count = len(chunks) - perf.info( - "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, - chunk_count, - time.perf_counter() - t_step, - ) - await persist_scratch_index( - self.session, - document, - content, - chunks, - batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE, - perf=perf, + else: + # Use the table-aware hybrid chunker so Markdown tables are not + # split mid-row (see issue #1334). 
+ chunk_texts = await asyncio.to_thread( + chunk_text_hybrid, + connector_doc.source_markdown, ) + + texts_to_embed = [content, *chunk_texts] + embeddings = await asyncio.to_thread(embed_texts, texts_to_embed) + summary_embedding, *chunk_embeddings = embeddings + + chunks = [ + Chunk(content=text, embedding=emb) + for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) + ] + perf.info( + "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", + document.id, + len(chunks), + time.perf_counter() - t_step, + ) + + document.content = content + document.embedding = summary_embedding + attach_chunks_to_document(document, chunks) + document.updated_at = datetime.now(UTC) + document.status = DocumentStatus.ready() + await self.session.commit() perf.info( "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", document.id, - chunk_count, + len(chunks), time.perf_counter() - t_index, ) - log_index_success(ctx, chunk_count=chunk_count) + log_index_success(ctx, chunk_count=len(chunks)) outcome_status = "success" await self._enqueue_ai_sort_if_enabled(document) @@ -481,89 +483,6 @@ class IndexingPipelineService: persist_span_cm.__exit__(*sys.exc_info()) return document - @staticmethod - def _reconcile_enabled() -> bool: - from app.config import config - - return config.CHUNK_RECONCILE_ENABLED - - async def _load_existing_chunks(self, document_id: int) -> list[ExistingChunk]: - result = await self.session.execute( - select(Chunk.id, Chunk.content, Chunk.position).where( - Chunk.document_id == document_id - ) - ) - return [ - ExistingChunk(id=row.id, content=row.content, position=row.position) - for row in result - ] - - async def _reindex_from_scratch( - self, document: Document, content: str, connector_doc: ConnectorDocument - ) -> list[Chunk]: - await self.session.execute( - delete(Chunk).where(Chunk.document_id == document.id) - ) - - summary_embedding, chunk_pairs = await build_chunk_embeddings( - content, - use_code_chunker=connector_doc.should_use_code_chunker, - ) - - 
document.embedding = summary_embedding - return [ - Chunk(content=text, embedding=emb, position=i) - for i, (text, emb) in enumerate(chunk_pairs) - ] - - async def _reindex_incrementally( - self, - document: Document, - content: str, - connector_doc: ConnectorDocument, - existing: list[ExistingChunk], - ) -> int: - """Edit path: keep rows whose text survived, embed only new texts. - - Unchanged rows keep their embedding and their HNSW/GIN index entries; - moved rows get a position-only UPDATE, which touches neither index. - """ - new_texts = await chunk_markdown( - content, use_code_chunker=connector_doc.should_use_code_chunker - ) - plan = reconcile(existing, new_texts) - - # One batch: the document-level summary vector plus the missing chunks. - embeddings = await embed_batch([content, *[t for _, t in plan.to_embed]]) - summary_embedding, *new_embeddings = embeddings - - if plan.reused: - await self.session.execute( - update(Chunk), - [{"id": cid, "position": pos} for cid, pos in plan.reused], - ) - if plan.to_delete: - await self.session.execute( - delete(Chunk).where(Chunk.id.in_(plan.to_delete)) - ) - self.session.add_all( - Chunk( - content=text, - embedding=emb, - position=pos, - document_id=document.id, - ) - for (pos, text), emb in zip(plan.to_embed, new_embeddings, strict=True) - ) - document.embedding = summary_embedding - - ot_metrics.record_chunk_reconcile( - reused=len(existing) - len(plan.to_delete), - embedded=len(plan.to_embed), - deleted=len(plan.to_delete), - ) - return len(new_texts) - async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled.""" try: diff --git a/surfsense_backend/app/notifications/constants.py b/surfsense_backend/app/notifications/constants.py index 4c7139972..6fc13e3c7 100644 --- a/surfsense_backend/app/notifications/constants.py +++ b/surfsense_backend/app/notifications/constants.py @@ -2,9 +2,6 @@ from __future__ import 
annotations -# Matches notifications.title VARCHAR(200). -TITLE_MAX_LENGTH = 200 - # Notifications newer than this are live-synced; older ones load via the list endpoint. SYNC_WINDOW_DAYS = 14 diff --git a/surfsense_backend/app/notifications/service/handlers/document_processing.py b/surfsense_backend/app/notifications/service/handlers/document_processing.py index 714c4f1aa..8644df2c8 100644 --- a/surfsense_backend/app/notifications/service/handlers/document_processing.py +++ b/surfsense_backend/app/notifications/service/handlers/document_processing.py @@ -28,7 +28,7 @@ class DocumentProcessingNotificationHandler(BaseNotificationHandler): ) -> Notification: """Open the notification when document processing is queued.""" operation_id = msg.operation_id(document_type, document_name, search_space_id) - title = msg.started_title(document_name) + title = f"Processing: {document_name}" message = "Waiting in queue" metadata = { diff --git a/surfsense_backend/app/notifications/service/messages/document_processing.py b/surfsense_backend/app/notifications/service/messages/document_processing.py index 1f324b35d..3805c2847 100644 --- a/surfsense_backend/app/notifications/service/messages/document_processing.py +++ b/surfsense_backend/app/notifications/service/messages/document_processing.py @@ -6,8 +6,6 @@ import hashlib from datetime import UTC, datetime from typing import Any -from app.notifications.service.messages.text import format_title - def operation_id(document_type: str, filename: str, search_space_id: int) -> str: """Build a unique id for a document processing run.""" @@ -16,11 +14,6 @@ def operation_id(document_type: str, filename: str, search_space_id: int) -> str return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}" -def started_title(document_name: str) -> str: - """Title shown when document processing is queued.""" - return format_title("Processing: ", document_name) - - def progress( stage: str, stage_message: str | None = None, @@ -51,11 
+44,11 @@ def completion( ) -> tuple[str, str, str, dict[str, Any]]: """Compute the final title, message, status, and metadata for a finished run.""" if error_message: - title = format_title("Failed: ", document_name) + title = f"Failed: {document_name}" message = f"Processing failed: {error_message}" status = "failed" else: - title = format_title("Ready: ", document_name) + title = f"Ready: {document_name}" message = "Now searchable!" status = "completed" diff --git a/surfsense_backend/app/notifications/service/messages/text.py b/surfsense_backend/app/notifications/service/messages/text.py index 344c9eb4e..98d5284cb 100644 --- a/surfsense_backend/app/notifications/service/messages/text.py +++ b/surfsense_backend/app/notifications/service/messages/text.py @@ -2,21 +2,7 @@ from __future__ import annotations -from app.notifications.constants import TITLE_MAX_LENGTH - def truncate(text: str, limit: int) -> str: """Return ``text`` capped at ``limit`` chars, appending an ellipsis if cut.""" return text[:limit] + "..." if len(text) > limit else text - - -def format_title(prefix: str, text: str, *, max_length: int = TITLE_MAX_LENGTH) -> str: - """Build a notification title that fits ``max_length`` including ``prefix``.""" - budget = max_length - len(prefix) - if budget <= 0: - return prefix[:max_length] - if len(text) <= budget: - return f"{prefix}{text}" - if budget <= 3: - return f"{prefix}{text[:budget]}" - return f"{prefix}{text[: budget - 3]}..." 
diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index ade43ab01..5ba3be059 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -289,49 +289,6 @@ def _etl_extract_outcome(): ) -@lru_cache(maxsize=1) -def _etl_cache_lookups(): - return _get_meter().create_counter( - "surfsense.etl.cache.lookups", - description="Count of ETL parse-cache lookups by outcome (hit/miss).", - ) - - -@lru_cache(maxsize=1) -def _etl_cache_evictions(): - return _get_meter().create_counter( - "surfsense.etl.cache.evictions", - description="Count of ETL parse-cache entries evicted, by phase.", - ) - - -@lru_cache(maxsize=1) -def _embedding_cache_lookups(): - return _get_meter().create_counter( - "surfsense.embedding.cache.lookups", - description="Count of embedding (chunk+embedding) cache lookups by outcome (hit/miss).", - ) - - -@lru_cache(maxsize=1) -def _embedding_cache_evictions(): - return _get_meter().create_counter( - "surfsense.embedding.cache.evictions", - description="Count of embedding cache entries evicted, by phase.", - ) - - -@lru_cache(maxsize=1) -def _chunk_reconcile_chunks(): - return _get_meter().create_counter( - "surfsense.indexing.reconcile.chunks", - description=( - "Chunks handled by incremental re-indexing, by outcome " - "(reused/embedded/deleted)." - ), - ) - - @lru_cache(maxsize=1) def _celery_heartbeat_refreshes(): return _get_meter().create_counter( @@ -713,61 +670,6 @@ def record_etl_extract_outcome( ) -def record_etl_cache_lookup( - *, etl_service: str | None, mode: str | None, outcome: str -) -> None: - """Record a parse-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" - _add( - _etl_cache_lookups(), - 1, - { - "etl.service": etl_service or "unknown", - "mode": mode or "unknown", - "outcome": outcome, - }, - ) - - -def record_etl_cache_eviction(count: int, *, phase: str) -> None: - """Record evicted entries. 
``phase`` is ``ttl`` or ``size``.""" - if count <= 0: - return - _add(_etl_cache_evictions(), count, {"phase": phase}) - - -def record_embedding_cache_lookup( - *, embedding_model: str | None, chunker_kind: str | None, outcome: str -) -> None: - """Record an embedding-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" - _add( - _embedding_cache_lookups(), - 1, - { - "embedding.model": embedding_model or "unknown", - "chunker.kind": chunker_kind or "unknown", - "outcome": outcome, - }, - ) - - -def record_embedding_cache_eviction(count: int, *, phase: str) -> None: - """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" - if count <= 0: - return - _add(_embedding_cache_evictions(), count, {"phase": phase}) - - -def record_chunk_reconcile(*, reused: int, embedded: int, deleted: int) -> None: - """Record an incremental re-index: how many chunks were kept vs recomputed.""" - for outcome, count in ( - ("reused", reused), - ("embedded", embedded), - ("deleted", deleted), - ): - if count > 0: - _add(_chunk_reconcile_chunks(), count, {"outcome": outcome}) - - def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) @@ -961,14 +863,9 @@ __all__ = [ "record_celery_queue_latency", "record_chat_request_duration", "record_chat_request_outcome", - "record_chunk_reconcile", "record_compaction_run", "record_connector_sync_duration", "record_connector_sync_outcome", - "record_embedding_cache_eviction", - "record_embedding_cache_lookup", - "record_etl_cache_eviction", - "record_etl_cache_lookup", "record_etl_extract_duration", "record_etl_extract_outcome", "record_indexing_document_duration", diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py index cfcb2ede9..80e5e1c64 100644 --- a/surfsense_backend/app/podcasts/api/routes.py +++ b/surfsense_backend/app/podcasts/api/routes.py @@ -27,14 +27,14 @@ from app.db import ( 
get_async_session, ) from app.podcasts.generation.brief import propose_brief -from app.podcasts.persistence import Podcast, PodcastRepository, PodcastStatus +from app.podcasts.persistence import Podcast, PodcastRepository from app.podcasts.service import ( InvalidTransitionError, PodcastService, PreconditionFailedError, SpecConflictError, ) -from app.podcasts.storage import audio_exists, open_audio_stream, purge_audio +from app.podcasts.storage import open_audio_stream, purge_audio from app.podcasts.tasks import draft_transcript_task from app.podcasts.tts import get_text_to_speech from app.podcasts.voices import ( @@ -47,7 +47,6 @@ from app.utils.rbac import check_permission from .schemas import ( CreatePodcastRequest, - LanguageOptions, PodcastDetail, PodcastSummary, UpdateSpecRequest, @@ -115,20 +114,6 @@ async def list_voices(language: str | None = None): ] -@router.get("/podcasts/languages", response_model=LanguageOptions) -async def list_languages(): - """Languages the active TTS provider can offer the brief editor.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - offering = get_voice_catalog().offerable_languages(provider) - return LanguageOptions( - languages=offering.languages, - allows_custom=offering.allows_custom, - ) - - @router.get("/podcasts/voices/{voice_id}/preview") async def preview_voice( voice_id: str, @@ -172,8 +157,8 @@ async def create_podcast( session, search_space_id=body.search_space_id, speaker_count=body.speaker_count, - min_seconds=body.min_seconds, - max_seconds=body.max_seconds, + min_minutes=body.min_minutes, + max_minutes=body.max_minutes, focus=body.focus, ) await service.attach_brief(podcast, spec) @@ -287,11 +272,6 @@ async def stream_podcast( podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) if podcast.storage_key: - # Verify first so a missing object is a 404, not a mid-stream 
crash. - if not await audio_exists(podcast): - raise HTTPException( - status_code=404, detail="Podcast audio is no longer available" - ) return StreamingResponse( open_audio_stream(podcast), media_type="audio/mpeg", @@ -315,10 +295,7 @@ async def stream_podcast( }, ) - # No audio: terminal states never will have any, otherwise it's in flight. - if PodcastStatus(podcast.status).is_terminal: - raise HTTPException(status_code=404, detail="Podcast audio not found") - raise HTTPException(status_code=409, detail="Podcast audio is not ready yet") + raise HTTPException(status_code=404, detail="Podcast audio not found") async def _require( diff --git a/surfsense_backend/app/podcasts/api/schemas.py b/surfsense_backend/app/podcasts/api/schemas.py index cb8559651..7f1f8cc7c 100644 --- a/surfsense_backend/app/podcasts/api/schemas.py +++ b/surfsense_backend/app/podcasts/api/schemas.py @@ -11,12 +11,6 @@ from datetime import datetime from pydantic import BaseModel, ConfigDict, Field -from app.podcasts.duration_limits import ( - DEFAULT_MAX_SECONDS, - DEFAULT_MIN_SECONDS, - MAX_DURATION_SECONDS, - MIN_DURATION_SECONDS, -) from app.podcasts.persistence import Podcast, PodcastStatus from app.podcasts.schemas import PodcastSpec, Transcript from app.podcasts.service import has_stored_episode, read_spec, read_transcript @@ -24,6 +18,8 @@ from app.podcasts.service import has_stored_episode, read_spec, read_transcript # Defaults applied when a create request omits brief sizing; the brief gate lets # the user adjust before any cost is incurred. 
DEFAULT_SPEAKER_COUNT = 2 +DEFAULT_MIN_MINUTES = 10 +DEFAULT_MAX_MINUTES = 20 class CreatePodcastRequest(BaseModel): @@ -34,16 +30,8 @@ class CreatePodcastRequest(BaseModel): source_content: str = Field(..., min_length=1) thread_id: int | None = None speaker_count: int = Field(default=DEFAULT_SPEAKER_COUNT, ge=1, le=6) - min_seconds: int = Field( - default=DEFAULT_MIN_SECONDS, - ge=MIN_DURATION_SECONDS, - le=MAX_DURATION_SECONDS, - ) - max_seconds: int = Field( - default=DEFAULT_MAX_SECONDS, - ge=MIN_DURATION_SECONDS, - le=MAX_DURATION_SECONDS, - ) + min_minutes: int = Field(default=DEFAULT_MIN_MINUTES, ge=1) + max_minutes: int = Field(default=DEFAULT_MAX_MINUTES, ge=1) focus: str | None = Field(default=None, max_length=2000) @@ -63,17 +51,6 @@ class VoiceOption(BaseModel): gender: str -class LanguageOptions(BaseModel): - """The languages the brief editor may offer for the active provider. - - When ``allows_custom`` is true the list is a curated starting point and - the editor accepts any BCP-47 tag beyond it. 
- """ - - languages: list[str] - allows_custom: bool - - class PodcastSummary(BaseModel): """Lightweight list item.""" diff --git a/surfsense_backend/app/podcasts/duration_limits.py b/surfsense_backend/app/podcasts/duration_limits.py deleted file mode 100644 index fc7d29890..000000000 --- a/surfsense_backend/app/podcasts/duration_limits.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Shared bounds and defaults for podcast target duration.""" - -MAX_DURATION_SECONDS = 24 * 60 * 60 -MIN_DURATION_SECONDS = 15 -DEFAULT_MIN_SECONDS = 20 -DEFAULT_MAX_SECONDS = 30 diff --git a/surfsense_backend/app/podcasts/generation/brief/config.py b/surfsense_backend/app/podcasts/generation/brief/config.py index 9b206bde4..4f92585ae 100644 --- a/surfsense_backend/app/podcasts/generation/brief/config.py +++ b/surfsense_backend/app/podcasts/generation/brief/config.py @@ -6,13 +6,10 @@ from dataclasses import dataclass, field, fields from langchain_core.runnables import RunnableConfig -from app.podcasts.duration_limits import ( - DEFAULT_MAX_SECONDS, - DEFAULT_MIN_SECONDS, -) - # Sensible defaults for a fresh brief; the user adjusts the range at the gate. 
DEFAULT_SPEAKER_COUNT = 2 +DEFAULT_MIN_MINUTES = 10 +DEFAULT_MAX_MINUTES = 20 @dataclass(kw_only=True) @@ -20,8 +17,8 @@ class BriefConfig: """Signals used to propose a brief; everything here is non-LLM context.""" speaker_count: int = DEFAULT_SPEAKER_COUNT - min_seconds: int = DEFAULT_MIN_SECONDS - max_seconds: int = DEFAULT_MAX_SECONDS + min_minutes: int = DEFAULT_MIN_MINUTES + max_minutes: int = DEFAULT_MAX_MINUTES focus: str | None = None last_used_language: str | None = None last_used_voices: list[str] = field(default_factory=list) diff --git a/surfsense_backend/app/podcasts/generation/brief/nodes.py b/surfsense_backend/app/podcasts/generation/brief/nodes.py index de6a9717e..c0a6f1ae1 100644 --- a/surfsense_backend/app/podcasts/generation/brief/nodes.py +++ b/surfsense_backend/app/podcasts/generation/brief/nodes.py @@ -79,7 +79,7 @@ def propose_spec(state: BriefState, config: RunnableConfig) -> dict[str, Any]: style=PodcastStyle.CONVERSATIONAL, speakers=speakers, duration=DurationTarget( - min_seconds=brief.min_seconds, max_seconds=brief.max_seconds + min_minutes=brief.min_minutes, max_minutes=brief.max_minutes ), focus=brief.focus, ) diff --git a/surfsense_backend/app/podcasts/generation/brief/propose.py b/surfsense_backend/app/podcasts/generation/brief/propose.py index 09d74840e..17344702b 100644 --- a/surfsense_backend/app/podcasts/generation/brief/propose.py +++ b/surfsense_backend/app/podcasts/generation/brief/propose.py @@ -4,12 +4,11 @@ from __future__ import annotations from sqlalchemy.ext.asyncio import AsyncSession -from app.podcasts.duration_limits import DEFAULT_MAX_SECONDS, DEFAULT_MIN_SECONDS from app.podcasts.persistence import PodcastRepository from app.podcasts.schemas import PodcastSpec from app.podcasts.service import preferences_from -from .config import DEFAULT_SPEAKER_COUNT +from .config import DEFAULT_MAX_MINUTES, DEFAULT_MIN_MINUTES, DEFAULT_SPEAKER_COUNT from .graph import graph as brief_graph from .state import BriefState @@ -19,8 
+18,8 @@ async def propose_brief( *, search_space_id: int, speaker_count: int = DEFAULT_SPEAKER_COUNT, - min_seconds: int = DEFAULT_MIN_SECONDS, - max_seconds: int = DEFAULT_MAX_SECONDS, + min_minutes: int = DEFAULT_MIN_MINUTES, + max_minutes: int = DEFAULT_MAX_MINUTES, focus: str | None = None, ) -> PodcastSpec: """Reuse the last-used language and voices, else English; return the spec.""" @@ -30,8 +29,8 @@ async def propose_brief( config = { "configurable": { "speaker_count": speaker_count, - "min_seconds": min_seconds, - "max_seconds": max_seconds, + "min_minutes": min_minutes, + "max_minutes": max_minutes, "focus": focus, "last_used_language": last_language, "last_used_voices": last_voices, diff --git a/surfsense_backend/app/podcasts/generation/transcript/nodes.py b/surfsense_backend/app/podcasts/generation/transcript/nodes.py index 7b472348d..44d6b219d 100644 --- a/surfsense_backend/app/podcasts/generation/transcript/nodes.py +++ b/surfsense_backend/app/podcasts/generation/transcript/nodes.py @@ -38,7 +38,7 @@ async def plan_outline( tc = TranscriptConfig.from_runnable_config(config) llm = await _require_llm(state, tc) - target_words = round(tc.spec.duration.midpoint_seconds * _WORDS_PER_MINUTE / 60) + target_words = round(tc.spec.duration.midpoint_minutes * _WORDS_PER_MINUTE) suggested_segments = max(1, round(target_words / _WORDS_PER_SEGMENT)) messages = [ diff --git a/surfsense_backend/app/podcasts/schemas/spec.py b/surfsense_backend/app/podcasts/schemas/spec.py index 3799d883b..1ef3dcfff 100644 --- a/surfsense_backend/app/podcasts/schemas/spec.py +++ b/surfsense_backend/app/podcasts/schemas/spec.py @@ -10,19 +10,17 @@ from __future__ import annotations import re from enum import StrEnum -from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from app.podcasts.duration_limits import ( - MAX_DURATION_SECONDS, - MIN_DURATION_SECONDS, -) - # A speaker count beyond this is almost never a real podcast and 
explodes the # voice/turn-attribution space, so we reject it at the brief gate. MAX_SPEAKERS = 6 +# Long-form is a goal, but an open-ended upper bound invites runaway TTS bills. +# One day of audio is a generous ceiling that still blocks obvious mistakes. +MAX_DURATION_MINUTES = 24 * 60 + # BCP-47 primary subtag plus optional region (e.g. ``en``, ``en-US``, ``pt-BR``). # Kept deliberately permissive: the voice catalog, not the brief, decides which # languages can actually be synthesised. Casing is normalised after matching. @@ -93,7 +91,7 @@ class SpeakerSpec(BaseModel): class DurationTarget(BaseModel): - """The desired finished length as an inclusive second range. + """The desired finished length as an inclusive minute range. Drafting aims for the midpoint and treats the bounds as soft guardrails; storing a range (rather than a point) keeps long-form expectations honest @@ -102,38 +100,19 @@ class DurationTarget(BaseModel): model_config = ConfigDict(extra="forbid") - min_seconds: int = Field(..., ge=MIN_DURATION_SECONDS, le=MAX_DURATION_SECONDS) - max_seconds: int = Field(..., ge=MIN_DURATION_SECONDS, le=MAX_DURATION_SECONDS) - - @model_validator(mode="before") - @classmethod - def _coerce_legacy_minutes(cls, data: Any) -> Any: - """Rows stored before seconds-based briefs still load from JSONB.""" - if ( - isinstance(data, dict) - and "min_seconds" not in data - and "min_minutes" in data - ): - migrated = dict(data) - migrated["min_seconds"] = int(migrated.pop("min_minutes")) * 60 - migrated["max_seconds"] = int(migrated.pop("max_minutes")) * 60 - return migrated - return data + min_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES) + max_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES) @model_validator(mode="after") def _check_order(self) -> DurationTarget: - if self.max_seconds < self.min_seconds: - raise ValueError("max_seconds must be >= min_seconds") + if self.max_minutes < self.min_minutes: + raise ValueError("max_minutes must be >= 
min_minutes") return self - @property - def midpoint_seconds(self) -> float: - """The runtime drafting should aim for within the range.""" - return (self.min_seconds + self.max_seconds) / 2 - @property def midpoint_minutes(self) -> float: - return self.midpoint_seconds / 60 + """The runtime drafting should aim for within the range.""" + return (self.min_minutes + self.max_minutes) / 2 class PodcastSpec(BaseModel): diff --git a/surfsense_backend/app/podcasts/storage.py b/surfsense_backend/app/podcasts/storage.py index c3326460d..f02429dff 100644 --- a/surfsense_backend/app/podcasts/storage.py +++ b/surfsense_backend/app/podcasts/storage.py @@ -42,13 +42,6 @@ def open_audio_stream(podcast: Podcast) -> AsyncIterator[bytes]: return get_storage_backend().open_stream(podcast.storage_key) -async def audio_exists(podcast: Podcast) -> bool: - """Whether the podcast's stored audio object is actually present.""" - return bool(podcast.storage_key) and await get_storage_backend().exists( - podcast.storage_key - ) - - async def purge_audio(podcast: Podcast) -> None: """Delete a podcast's stored audio if present; a missing object is fine.""" await purge_audio_object(podcast.storage_key) diff --git a/surfsense_backend/app/podcasts/voices/__init__.py b/surfsense_backend/app/podcasts/voices/__init__.py index 97874a655..ab1f8bbbf 100644 --- a/surfsense_backend/app/podcasts/voices/__init__.py +++ b/surfsense_backend/app/podcasts/voices/__init__.py @@ -6,7 +6,7 @@ configured provider via :func:`provider_from_service`. 
from __future__ import annotations -from .catalog import LanguageOffering, VoiceCatalog, get_voice_catalog +from .catalog import VoiceCatalog, get_voice_catalog from .preview import render_voice_preview from .provider import TtsProvider, provider_from_service from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender @@ -14,7 +14,6 @@ from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender __all__ = [ "ANY_LANGUAGE", "CatalogVoice", - "LanguageOffering", "TtsProvider", "VoiceCatalog", "VoiceGender", diff --git a/surfsense_backend/app/podcasts/voices/catalog.py b/surfsense_backend/app/podcasts/voices/catalog.py index 6bf39510a..c36313a0c 100644 --- a/surfsense_backend/app/podcasts/voices/catalog.py +++ b/surfsense_backend/app/podcasts/voices/catalog.py @@ -9,26 +9,11 @@ provider-native reference. from __future__ import annotations from collections.abc import Iterable -from dataclasses import dataclass from functools import lru_cache from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES -from .data.languages import COMMON_LANGUAGES from .provider import TtsProvider -from .voice import ANY_LANGUAGE, CatalogVoice - - -@dataclass(frozen=True, slots=True) -class LanguageOffering: - """The languages a provider's roster can offer the brief form. - - ``allows_custom`` is true when the roster has wildcard voices: the listed - languages are then a curated starting point, not a limit, and any BCP-47 - tag may be entered. - """ - - languages: list[str] - allows_custom: bool +from .voice import CatalogVoice class VoiceCatalog: @@ -59,20 +44,6 @@ class VoiceCatalog: """Whether ``provider`` has at least one voice for ``language``.""" return any(v.speaks(language) for v in self.for_provider(provider)) - def offerable_languages(self, provider: TtsProvider) -> LanguageOffering: - """The languages ``provider`` can offer up front. 
- - Language-bound voices contribute their concrete tags; wildcard voices - cannot enumerate languages, so their presence merges in the curated - common list and opens free entry. - """ - voices = self.for_provider(provider) - tags = {v.language for v in voices if v.language != ANY_LANGUAGE} - has_wildcard = any(v.language == ANY_LANGUAGE for v in voices) - if has_wildcard: - tags.update(COMMON_LANGUAGES) - return LanguageOffering(languages=sorted(tags), allows_custom=has_wildcard) - @lru_cache(maxsize=1) def get_voice_catalog() -> VoiceCatalog: diff --git a/surfsense_backend/app/podcasts/voices/data/languages.py b/surfsense_backend/app/podcasts/voices/data/languages.py deleted file mode 100644 index c00fd7f05..000000000 --- a/surfsense_backend/app/podcasts/voices/data/languages.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Curated languages offered when a roster has wildcard (any-language) voices. - -OpenAI-style multilingual voices speak whatever language the text is in, so -there is no provider list to enumerate. This is the set the brief form offers -up front for such providers; it is an offering, not a limit — the API flags -``allows_custom`` so users can enter any BCP-47 tag beyond it. -""" - -from __future__ import annotations - -COMMON_LANGUAGES: tuple[str, ...] = ( - "ar", - "bn", - "de", - "en", - "es", - "fr", - "hi", - "id", - "it", - "ja", - "ko", - "nl", - "pl", - "pt", - "ru", - "sw", - "th", - "tr", - "uk", - "vi", - "zh", -) diff --git a/surfsense_backend/app/prompts/default_system_instructions.py b/surfsense_backend/app/prompts/default_system_instructions.py index b968fc1f0..fd0a8e186 100644 --- a/surfsense_backend/app/prompts/default_system_instructions.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -82,7 +82,7 @@ def build_configurable_system_prompt( *, model_name: str | None = None, ) -> str: - """Build a configurable SurfSense system prompt. + """Build a configurable SurfSense system prompt (NewLLMConfig path). 
See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. @@ -104,7 +104,7 @@ def build_configurable_system_prompt( def get_default_system_instructions() -> str: """Return the default ```` block (no tools / citations). - Useful for populating the UI when editing custom system instructions. + Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. The output reflects the current fragment tree, not a baked-in constant. """ resolved_today = datetime.now(UTC).date().isoformat() diff --git a/surfsense_backend/app/prompts/system_prompt_composer/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py index c639d4aa0..3849af313 100644 --- a/surfsense_backend/app/prompts/system_prompt_composer/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -348,7 +348,8 @@ def compose_system_prompt( mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject an explicit MCP routing block. custom_system_instructions: Free-form instructions that override - the default ```` block. + the default ```` block (legacy support + for ``NewLLMConfig.system_instructions``). use_default_system_instructions: When ``custom_system_instructions`` is empty/None, fall back to defaults (legacy semantics). 
citations_enabled: Include ``citations_on.md`` (true) or diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 5e5edec2e..47f7fe6b1 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -420,10 +420,7 @@ class ChucksHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -444,7 +441,7 @@ class ChucksHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(chunk_filter) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index d856e93cf..9ce86d404 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -357,10 +357,7 @@ class DocumentHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over( - partition_by=Chunk.document_id, - order_by=(Chunk.position, Chunk.id), - ) + .over(partition_by=Chunk.document_id, order_by=Chunk.id) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -372,7 +369,7 @@ class DocumentHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC) - .order_by(Chunk.document_id, Chunk.position, Chunk.id) + .order_by(Chunk.document_id, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/routes/__init__.py 
b/surfsense_backend/app/routes/__init__.py index 8ce84d179..a050651f6 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -24,10 +24,7 @@ from .dropbox_add_connector_route import router as dropbox_add_connector_router from .editor_routes import router as editor_router from .export_routes import router as export_router from .folders_routes import router as folders_router -from .gateway_webhook_routes import ( - config_router as gateway_config_router, - router as gateway_router, -) +from .gateway_webhook_routes import router as gateway_router from .gateway_whatsapp_baileys_routes import router as gateway_whatsapp_baileys_router from .gateway_whatsapp_webhook_routes import router as gateway_whatsapp_webhook_router from .google_calendar_add_connector_route import ( @@ -47,9 +44,9 @@ from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .mcp_oauth_route import router as mcp_oauth_router from .memory_routes import router as memory_router -from .model_connections_routes import router as model_connections_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router +from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router @@ -66,6 +63,7 @@ from .stripe_routes import router as stripe_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router +from .vision_llm_routes import router as vision_llm_router from .youtube_routes import router as youtube_router router = APIRouter() @@ -77,7 +75,6 @@ router.include_router(export_router) 
router.include_router(documents_router) router.include_router(folders_router) _gateway_enabled_dep = [Depends(require_gateway_enabled)] -router.include_router(gateway_config_router) router.include_router(gateway_router, dependencies=_gateway_enabled_dep) router.include_router( gateway_whatsapp_webhook_router, dependencies=_gateway_enabled_dep @@ -101,6 +98,7 @@ router.include_router( ) # Video presentation status and streaming router.include_router(reports_router) # Report CRUD and multi-format export router.include_router(image_generation_router) # Image generation via litellm +router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) @@ -118,7 +116,7 @@ router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) -router.include_router(model_connections_router) # Connection-centric model catalog +router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index f6f984c20..ad3277375 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -18,7 +18,6 @@ from app.etl_pipeline.file_classifier import ( PLAINTEXT_EXTENSIONS, ) from app.rate_limiter import limiter -from app.tasks.chat.streaming.errors.classifier import classify_stream_exception logger = logging.getLogger(__name__) @@ -99,6 +98,7 @@ class 
AnonQuotaResponse(BaseModel): class AnonModelResponse(BaseModel): id: int name: str + description: str | None = None provider: str model_name: str billing_tier: str = "free" @@ -131,7 +131,8 @@ async def list_anonymous_models(): AnonModelResponse( id=cfg.get("id", 0), name=cfg.get("name", ""), - provider=cfg.get("provider") or cfg.get("litellm_provider", ""), + description=cfg.get("description"), + provider=cfg.get("provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -159,7 +160,8 @@ async def get_anonymous_model(slug: str): return AnonModelResponse( id=cfg.get("id", 0), name=cfg.get("name", ""), - provider=cfg.get("provider") or cfg.get("litellm_provider", ""), + description=cfg.get("description"), + provider=cfg.get("provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -472,15 +474,7 @@ async def stream_anonymous_chat( except Exception as e: logger.exception("Anonymous chat stream error") await TokenQuotaService.anon_release(session_key, ip_key, request_id) - _, error_code, _, _, user_message, extra = classify_stream_exception( - e, - flow_label="chat", - ) - yield streaming_service.format_error( - user_message, - error_code=error_code, - extra=extra, - ) + yield streaming_service.format_error(f"Error during chat: {e!s}") yield streaming_service.format_done() finally: await TokenQuotaService.anon_release_stream_slot(client_ip) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 53f03a0ca..865068fba 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -1014,8 +1014,8 @@ async def get_document_by_chunk_id( .filter( Chunk.document_id == document.id, or_( - Chunk.position < chunk.position, - and_(Chunk.position == 
chunk.position, Chunk.id < chunk.id), + Chunk.created_at < chunk.created_at, + and_(Chunk.created_at == chunk.created_at, Chunk.id < chunk.id), ), ) ) @@ -1027,7 +1027,7 @@ async def get_document_by_chunk_id( windowed_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.created_at, Chunk.id) .offset(start) .limit(end - start) ) @@ -1137,7 +1137,7 @@ async def get_document_chunks_paginated( chunks_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.created_at, Chunk.id) .offset(offset) .limit(page_size) ) diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 8250fff98..166164c50 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -38,8 +38,7 @@ logger = logging.getLogger(__name__) router = APIRouter() -EDITOR_PLATE_MAX_BYTES = 1 * 1024 * 1024 -EDITOR_PLATE_MAX_LINES = 5000 +EDITOR_PLATE_MAX_BYTES = 5 * 1024 * 1024 @router.get("/search-spaces/{search_space_id}/documents/{document_id}/editor-content") @@ -84,22 +83,16 @@ async def get_editor_content( def _build_response(md: str) -> dict: size_bytes = len(md.encode("utf-8")) - line_count = md.count("\n") + 1 - too_large = ( - size_bytes > EDITOR_PLATE_MAX_BYTES or line_count > EDITOR_PLATE_MAX_LINES - ) - viewer_mode = "monaco" if too_large else "plate" + viewer_mode = "monaco" if size_bytes > EDITOR_PLATE_MAX_BYTES else "plate" return { "document_id": document.id, "title": document.title, "document_type": document.document_type.value, "source_markdown": md, "content_size_bytes": size_bytes, - "line_count": line_count, "chunk_count": chunk_count, "viewer_mode": viewer_mode, "editor_plate_max_bytes": EDITOR_PLATE_MAX_BYTES, - "editor_plate_max_lines": EDITOR_PLATE_MAX_LINES, "updated_at": document.updated_at.isoformat() 
if document.updated_at else None, @@ -126,7 +119,7 @@ async def get_editor_content( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() @@ -212,7 +205,7 @@ async def download_document_markdown( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: @@ -361,7 +354,7 @@ async def export_document( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: diff --git a/surfsense_backend/app/routes/gateway_webhook_routes.py b/surfsense_backend/app/routes/gateway_webhook_routes.py index 9b4af4b83..14f929567 100644 --- a/surfsense_backend/app/routes/gateway_webhook_routes.py +++ b/surfsense_backend/app/routes/gateway_webhook_routes.py @@ -56,7 +56,6 @@ from app.utils.oauth_security import OAuthStateManager, TokenEncryption from app.utils.rbac import check_search_space_access router = APIRouter(prefix="/gateway", tags=["gateway"]) -config_router = APIRouter(prefix="/gateway", tags=["gateway"]) logger = logging.getLogger(__name__) SLACK_AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize" @@ -968,20 +967,11 @@ async def list_platforms( ] -@config_router.get("/config") +@router.get("/config") async def get_gateway_config( user: User = Depends(current_active_user), ) -> dict[str, bool | str]: - if not config.GATEWAY_ENABLED: - return { - "enabled": False, - "telegram_enabled": False, - "whatsapp_intake_mode": "disabled", - "slack_enabled": False, - "discord_enabled": False, - } return { - "enabled": True, "telegram_enabled": 
_telegram_gateway_enabled(), "whatsapp_intake_mode": config.GATEWAY_WHATSAPP_INTAKE_MODE, "slack_enabled": _slack_gateway_enabled(), diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index cc3e51ed5..33caf8453 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -1,5 +1,7 @@ """ Image Generation routes: +- CRUD for ImageGenerationConfig (user-created image model configs) +- Global image gen configs endpoint (from YAML) - Image generation execution (calls litellm.aimage_generation()) - CRUD for ImageGeneration records (results) - Image serving endpoint (serves b64_json images from DB, protected by signed tokens) @@ -14,12 +16,11 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, - Model, + ImageGenerationConfig, Permission, SearchSpace, SearchSpaceMembership, @@ -27,14 +28,14 @@ from app.db import ( get_async_session, ) from app.schemas import ( + GlobalImageGenConfigRead, + ImageGenerationConfigCreate, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) from app.services.billable_calls import ( DEFAULT_IMAGE_RESERVE_MICROS, QuotaInsufficientError, @@ -42,10 +43,10 @@ from app.services.billable_calls import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, + ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm +from app.services.provider_api_base import resolve_api_base from 
app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -53,16 +54,52 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) - -def _get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) +# Provider mapping for building litellm model strings. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} -def _get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) +def _get_global_image_gen_config(config_id: int) -> dict | None: + """Get a global image generation configuration by ID (negative IDs).""" + if config_id == IMAGE_GEN_AUTO_MODE_ID: + return { + "id": IMAGE_GEN_AUTO_MODE_ID, + "name": "Auto (Fastest)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None + + +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + """Resolve the LiteLLM provider prefix used in model strings.""" + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + """Build a litellm model string from provider + model_name.""" + return f"{_resolve_provider_prefix(provider, 
custom_provider)}/{model_name}" async def _resolve_billing_for_image_gen( @@ -78,41 +115,34 @@ async def _resolve_billing_for_image_gen( config that will actually run, and so we don't open an ``ImageGeneration`` row for a request that's about to 402. - User-owned (positive ID) BYOK models are always free — they cost - the user nothing on our side. Auto mode resolves to one concrete - global or BYOK model before billing is calculated. + User-owned (positive ID) BYOK configs are always free — they cost + the user nothing on our side. Auto mode currently treats as free + because the underlying router can dispatch to either premium or + free YAML configs and we don't surface the resolved deployment up + here yet. Bringing Auto under premium billing would require + threading the chosen deployment back from ``ImageGenRouterService``. """ resolved_id = config_id if resolved_id is None: - resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID + resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space.id, - user_id=search_space.user_id, - capability="image_gen", - ) - if not candidates: - return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) - selected = choose_auto_model_candidate(candidates, search_space.id) - resolved_id = int(selected["id"]) + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) if resolved_id < 0: - global_model = _get_global_model(resolved_id) or {} - global_connection = _get_global_connection(global_model.get("connection_id", 0)) - billing_tier = str(global_model.get("billing_tier", "free")).lower() - if global_connection and global_model.get("model_id"): - base_model, _ = to_litellm(global_connection, global_model["model_id"]) - else: - base_model = "global_image_model" - catalog = global_model.get("catalog") or {} + cfg = _get_global_image_gen_config(resolved_id) or {} + billing_tier 
= str(cfg.get("billing_tier", "free")).lower() + base_model = _build_model_string( + cfg.get("provider", ""), + cfg.get("model_name", ""), + cfg.get("custom_provider"), + ) reserve_micros = int( - catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) return (billing_tier, base_model, reserve_micros) - # Positive ID = user-owned BYOK image-gen model — always free. + # Positive ID = user-owned BYOK image-gen config — always free. return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) @@ -125,14 +155,14 @@ async def _execute_image_generation( Call litellm.aimage_generation() with the appropriate config. Resolution order: - 1. Explicit image_gen_model_id on the request - 2. Search space's image_gen_model_id preference + 1. Explicit image_generation_config_id on the request + 2. Search space's image_generation_config_id preference 3. Falls back to Auto mode if available """ - config_id = image_gen.image_gen_model_id + config_id = image_gen.image_generation_config_id if config_id is None: - config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID - image_gen.image_gen_model_id = config_id + config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + image_gen.image_generation_config_id = config_id # Build kwargs gen_kwargs = {} @@ -148,30 +178,36 @@ async def _execute_image_generation( gen_kwargs["response_format"] = image_gen.response_format if is_image_gen_auto_mode(config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space.id, - user_id=search_space.user_id, - capability="image_gen", + if not ImageGenRouterService.is_initialized(): + raise ValueError( + "Auto mode requested but Image Generation Router not initialized. " + "Ensure global_llm_config.yaml has global_image_generation_configs." 
+ ) + response = await ImageGenRouterService.aimage_generation( + prompt=image_gen.prompt, model="auto", **gen_kwargs ) - if not candidates: - raise ValueError("No image-generation models are available for Auto mode") - config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) - image_gen.image_gen_model_id = config_id + elif config_id < 0: + # Global config from YAML + cfg = _get_global_image_gen_config(config_id) + if not cfg: + raise ValueError(f"Global image generation config {config_id} not found") - if config_id < 0: - global_model = _get_global_model(config_id) - if not global_model or not has_capability(global_model, "image_gen"): - raise ValueError(f"Global image generation model {config_id} not found") - global_connection = _get_global_connection(global_model["connection_id"]) - if not global_connection: - raise ValueError(f"Global connection for image model {config_id} not found") - - model_string, resolved_kwargs = to_litellm( - global_connection, - global_model["model_id"], + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) - gen_kwargs.update(resolved_kwargs) + model_string = f"{provider_prefix}/{cfg['model_name']}" + gen_kwargs["api_key"] = cfg.get("api_key") + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) # User model override if image_gen.model: @@ -181,28 +217,30 @@ async def _execute_image_generation( prompt=image_gen.prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = Model + Connection + # Positive ID = DB ImageGenerationConfig result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .filter(Model.id == config_id, 
Model.enabled.is_(True)) + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) ) - db_model = result.scalars().first() - if not db_model or not db_model.connection or not db_model.connection.enabled: - raise ValueError(f"Image generation model {config_id} not found") - conn = db_model.connection - if conn.search_space_id is not None and conn.search_space_id != search_space.id: - raise ValueError(f"Image generation model {config_id} not found") - if conn.user_id is not None and conn.user_id != search_space.user_id: - raise ValueError(f"Image generation model {config_id} not found") - if not has_capability(db_model, "image_gen"): - raise ValueError(f"Model {config_id} is not image-generation capable") + db_cfg = result.scalars().first() + if not db_cfg: + raise ValueError(f"Image generation config {config_id} not found") - model_string, resolved_kwargs = to_litellm( - db_model.connection, - db_model.model_id, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) - gen_kwargs.update(resolved_kwargs) + model_string = f"{provider_prefix}/{db_cfg.model_name}" + gen_kwargs["api_key"] = db_cfg.api_key + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) # User model override if image_gen.model: @@ -222,6 +260,266 @@ async def _execute_image_generation( image_gen.model = hidden["model"] +# ============================================================================= +# Global Image Generation Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-image-generation-configs", + response_model=list[GlobalImageGenConfigRead], +) +async def 
get_global_image_gen_configs( + user: User = Depends(current_active_user), +): + """Get all global image generation configs. API keys are hidden.""" + try: + global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS + safe_configs = [] + + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes across available image generation providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + # Auto mode currently treated as free until per-deployment + # billing-tier surfacing lands (see _resolve_billing_for_image_gen). + "billing_tier": "free", + "is_premium": False, + } + ) + + for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. 
+ "is_premium": billing_tier == "premium", + "quota_reserve_micros": cfg.get("quota_reserve_micros"), + } + ) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global image generation configs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +# ============================================================================= +# ImageGenerationConfig CRUD +# ============================================================================= + + +@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) +async def create_image_gen_config( + config_data: ImageGenerationConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Create a new image generation config for a search space.""" + try: + await check_permission( + session, + user, + config_data.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to create image generation configs in this search space", + ) + + db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e + + +@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) +async def list_image_gen_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """List image generation configs for a search space.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to 
view image generation configs in this search space", + ) + + result = await session.execute( + select(ImageGenerationConfig) + .filter(ImageGenerationConfig.search_space_id == search_space_id) + .order_by(ImageGenerationConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list ImageGenerationConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +@router.get( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def get_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Get a specific image generation config by ID.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_READ.value, + "You don't have permission to view image generation configs in this search space", + ) + return db_config + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e + + +@router.put( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) +async def update_image_gen_config( + config_id: int, + update_data: ImageGenerationConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Update an existing image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == 
config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_CREATE.value, + "You don't have permission to update image generation configs in this search space", + ) + + for key, value in update_data.model_dump(exclude_unset=True).items(): + setattr(db_config, key, value) + + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e + + +@router.delete("/image-generation-configs/{config_id}", response_model=dict) +async def delete_image_gen_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Delete an image generation config.""" + try: + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.IMAGE_GENERATIONS_DELETE.value, + "You don't have permission to delete image generation configs in this search space", + ) + + await session.delete(db_config) + await session.commit() + return { + "message": "Image generation config deleted successfully", + "id": config_id, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete ImageGenerationConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e + + # 
============================================================================= # Image Generation Execution + Results CRUD # ============================================================================= @@ -270,7 +568,7 @@ async def create_image_generation( raise HTTPException(status_code=404, detail="Search space not found") billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( - session, data.image_gen_model_id, search_space + session, data.image_generation_config_id, search_space ) # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError @@ -296,7 +594,7 @@ async def create_image_generation( size=data.size, style=data.style, response_format=data.response_format, - image_gen_model_id=data.image_gen_model_id, + image_generation_config_id=data.image_generation_config_id, search_space_id=data.search_space_id, created_by_id=user.id, ) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py deleted file mode 100644 index 4d32a32af..000000000 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ /dev/null @@ -1,811 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from app.config import config -from app.db import ( - Connection, - ConnectionScope, - Model, - ModelSource, - NewChatThread, - Permission, - SearchSpace, - User, - get_async_session, -) -from app.schemas import ( - ConnectionCreate, - ConnectionRead, - ConnectionUpdate, - ModelCreate, - ModelPreviewRead, - ModelProviderRead, - ModelRead, - ModelRolesRead, - ModelRolesUpdate, - ModelsBulkUpdate, - ModelSelection, - ModelTestPreview, - ModelUpdate, - VerifyConnectionResponse, -) -from app.services.model_capabilities import has_capability -from app.services.model_connection_service import ( - ModelDiscoveryError, - 
derive_capabilities, - discover_models, - test_model, - verify_connection, -) -from app.services.provider_registry import REGISTRY -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -def _model_read(model: Model | dict) -> ModelRead: - return ModelRead.model_validate(model) - - -def _preview_model_read(item: dict) -> ModelPreviewRead: - return ModelPreviewRead( - model_id=item["model_id"], - display_name=item.get("display_name"), - source=item.get("source", ModelSource.DISCOVERED), - supports_chat=item.get("supports_chat"), - max_input_tokens=item.get("max_input_tokens"), - supports_image_input=item.get("supports_image_input"), - supports_tools=item.get("supports_tools"), - supports_image_generation=item.get("supports_image_generation"), - enabled=item.get("enabled", False), - metadata=item.get("metadata") or item.get("catalog") or {}, - ) - - -def _connection_read( - conn: Connection | dict, models: list[Model | dict] | None = None -) -> ConnectionRead: - if isinstance(conn, dict): - payload = { - **conn, - "has_api_key": bool(conn.get("api_key")), - "api_key": None, - "models": [_model_read(model) for model in (models or [])], - } - payload.pop("api_key", None) - return ConnectionRead.model_validate(payload) - - return ConnectionRead( - id=conn.id, - provider=conn.provider, - base_url=conn.base_url, - api_key=conn.api_key, - extra=conn.extra or {}, - scope=conn.scope, - search_space_id=conn.search_space_id, - user_id=conn.user_id, - enabled=conn.enabled, - has_api_key=bool(conn.api_key), - models=[_model_read(model) for model in (models or [])], - created_at=conn.created_at, - ) - - -def _apply_model_facts(model: Model, facts: dict) -> None: - model.supports_chat = facts.get("supports_chat") - model.max_input_tokens = facts.get("max_input_tokens") - model.supports_image_input = facts.get("supports_image_input") - model.supports_tools = 
facts.get("supports_tools") - model.supports_image_generation = facts.get("supports_image_generation") - - -def _complete_selection_facts(conn: Connection, selection: ModelSelection) -> dict: - facts = selection.model_dump() - derived = derive_capabilities(conn, selection.model_id.strip(), selection.metadata) - for key, value in derived.items(): - if facts.get(key) is None: - facts[key] = value - return facts - - -def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model: - source = ( - selection.source - if isinstance(selection.source, ModelSource) - else ModelSource(selection.source) - ) - model = Model( - connection_id=conn.id, - model_id=selection.model_id.strip(), - display_name=selection.display_name, - source=source, - capabilities_override={}, - enabled=selection.enabled, - catalog=selection.metadata, - ) - _apply_model_facts(model, _complete_selection_facts(conn, selection)) - return model - - -def _default_model_for(models: list[Model], capability: str) -> int | None: - for model in models: - if model.enabled and has_capability(model, capability): - return model.id - return None - - -async def _load_role_model( - session: AsyncSession, - search_space_id: int, - model_id: int, -) -> Model | dict | None: - if model_id < 0: - return next( - (model for model in config.GLOBAL_MODELS if model.get("id") == model_id), - None, - ) - - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if model is None or model.connection.search_space_id != search_space_id: - return None - return model - - -def _role_model_enabled(model: Model | dict) -> bool: - if isinstance(model, dict): - return bool(model.get("enabled", True)) - return bool(model.enabled and model.connection.enabled) - - -async def _validate_role_model_id( - session: AsyncSession, - *, - search_space_id: int, - model_id: int | None, - capability: str, -) -> int: - if model_id is 
None or model_id == 0: - return 0 - - model = await _load_role_model(session, search_space_id, model_id) - if model and _role_model_enabled(model) and has_capability(model, capability): - return model_id - - raise HTTPException( - status_code=400, - detail=f"Selected model is not available for {capability}", - ) - - -async def _resolve_role_model_id( - session: AsyncSession, - *, - search_space_id: int, - model_id: int | None, - capability: str, -) -> int: - try: - return await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=model_id, - capability=capability, - ) - except HTTPException: - return 0 - - -async def _clear_invalid_roles( - session: AsyncSession, search_space_id: int -) -> SearchSpace: - search_space = await _get_search_space(session, search_space_id) - search_space.chat_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.chat_model_id, - capability="chat", - ) - search_space.vision_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.vision_model_id, - capability="vision", - ) - search_space.image_gen_model_id = await _resolve_role_model_id( - session, - search_space_id=search_space_id, - model_id=search_space.image_gen_model_id, - capability="image_gen", - ) - return search_space - - -async def _default_unset_roles( - session: AsyncSession, - conn: Connection, - models: list[Model], -) -> None: - if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None: - return - search_space = await _get_search_space(session, conn.search_space_id) - if search_space.chat_model_id is None: - search_space.chat_model_id = _default_model_for(models, "chat") - if search_space.vision_model_id is None: - vision_default = None - if search_space.chat_model_id: - chat_model = next( - (m for m in models if m.id == search_space.chat_model_id), None - ) - if chat_model and has_capability(chat_model, "vision"): 
- vision_default = chat_model.id - search_space.vision_model_id = vision_default or _default_model_for( - models, "vision" - ) - if search_space.image_gen_model_id is None: - search_space.image_gen_model_id = _default_model_for(models, "image_gen") - - -@router.get("/model-providers", response_model=list[ModelProviderRead]) -async def list_model_providers(user: User = Depends(current_active_user)): - del user - local_only = {"ollama_chat", "lm_studio"} - return [ - ModelProviderRead( - provider=provider, - transport=spec.transport.value, - discovery=spec.discovery, - default_base_url=spec.default_base_url, - base_url_required=spec.base_url_required, - auth_style=spec.auth_style, - local_only=provider in local_only, - ) - for provider, spec in sorted(REGISTRY.items()) - ] - - -async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: - result = await session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - return search_space - - -async def _load_connection(session: AsyncSession, connection_id: int) -> Connection: - result = await session.execute( - select(Connection) - .options(selectinload(Connection.models)) - .where(Connection.id == connection_id) - ) - conn = result.scalars().first() - if not conn: - raise HTTPException(status_code=404, detail="Connection not found") - return conn - - -async def _assert_connection_access( - session: AsyncSession, - user: User, - conn: Connection, - permission: str = Permission.LLM_CONFIGS_CREATE.value, -) -> None: - if conn.search_space_id: - await check_permission( - session, - user, - conn.search_space_id, - permission, - "You don't have permission to manage model connections in this search space", - ) - return - if conn.user_id != user.id: - raise HTTPException( - status_code=403, detail="Connection does not belong to user" - ) - - 
-@router.get("/global-llm-config-status") -async def global_llm_config_status(user: User = Depends(current_active_user)): - del user - return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS} - - -@router.get("/global-model-connections", response_model=list[ConnectionRead]) -async def list_global_connections(user: User = Depends(current_active_user)): - del user - models_by_connection: dict[int, list[dict]] = {} - for model in config.GLOBAL_MODELS: - models_by_connection.setdefault(model["connection_id"], []).append(model) - return [ - _connection_read(conn, models_by_connection.get(conn["id"], [])) - for conn in config.GLOBAL_CONNECTIONS - ] - - -@router.get("/model-connections", response_model=list[ConnectionRead]) -async def list_connections( - search_space_id: int | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - stmt = select(Connection).options(selectinload(Connection.models)) - if search_space_id is not None: - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to view model connections in this search space", - ) - stmt = stmt.where(Connection.search_space_id == search_space_id) - else: - stmt = stmt.where(Connection.user_id == user.id) - result = await session.execute(stmt.order_by(Connection.id)) - return [ - _connection_read(conn, list(conn.models)) for conn in result.scalars().all() - ] - - -@router.post("/model-connections", response_model=ConnectionRead) -async def create_connection( - data: ConnectionCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.GLOBAL: - raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") - if data.scope == ConnectionScope.SEARCH_SPACE: - if data.search_space_id is None: - raise HTTPException(status_code=400, detail="search_space_id is required") - await 
check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - payload = data.model_dump(exclude={"search_space_id", "models"}) - - conn = Connection( - **payload, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - session.add(conn) - await session.flush() - - seen_model_ids: set[str] = set() - for selection in data.models: - model_id = selection.model_id.strip() - if not model_id or model_id in seen_model_ids: - continue - seen_model_ids.add(model_id) - session.add(_selection_to_model(conn, selection)) - - await session.commit() - conn = await _load_connection(session, conn.id) - await _default_unset_roles(session, conn, list(conn.models)) - await session.commit() - conn = await _load_connection(session, conn.id) - return _connection_read(conn, list(conn.models)) - - -@router.post( - "/model-connections/discover-preview", response_model=list[ModelPreviewRead] -) -async def preview_connection_models( - data: ConnectionCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: - await check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - - draft = Connection( - provider=data.provider, - base_url=data.base_url, - api_key=data.api_key, - extra=data.extra or {}, - scope=data.scope, - enabled=data.enabled, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - try: - discovered = await discover_models(draft) - except ModelDiscoveryError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - return [_preview_model_read(item) 
for item in discovered] - - -@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse) -async def test_preview_connection_model( - data: ModelTestPreview, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: - await check_permission( - session, - user, - data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create model connections in this search space", - ) - - model_id = data.model_id.strip() - if not model_id: - raise HTTPException(status_code=400, detail="model_id is required") - - draft = Connection( - provider=data.provider, - base_url=data.base_url, - api_key=data.api_key, - extra=data.extra or {}, - scope=data.scope, - enabled=data.enabled, - search_space_id=data.search_space_id - if data.scope == ConnectionScope.SEARCH_SPACE - else None, - user_id=user.id, - ) - model = Model( - connection_id=0, - model_id=model_id, - source=ModelSource.MANUAL, - enabled=True, - capabilities_override={}, - catalog={}, - ) - result = await test_model(draft, model) - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) -async def update_connection( - connection_id: int, - data: ConnectionUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = conn.search_space_id - for key, value in data.model_dump(exclude_unset=True).items(): - setattr(conn, key, value) - await session.commit() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - conn = await 
_load_connection(session, connection_id) - return _connection_read(conn, list(conn.models)) - - -@router.delete("/model-connections/{connection_id}") -async def delete_connection( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_DELETE.value - ) - search_space_id = conn.search_space_id - await session.delete(conn) - await session.commit() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - return {"status": "deleted"} - - -@router.post( - "/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse -) -async def verify_model_connection( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value - ) - result = await verify_connection(conn) - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.post( - "/model-connections/{connection_id}/discover", response_model=list[ModelRead] -) -async def discover_connection_models( - connection_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_CREATE.value - ) - try: - discovered = await discover_models(conn) - except ModelDiscoveryError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - by_model_id = {model.model_id: model for model in conn.models} - for item in discovered: - db_model = by_model_id.get(item["model_id"]) - if db_model is None: - 
db_model = Model( - connection_id=conn.id, - model_id=item["model_id"], - display_name=item.get("display_name"), - source=item["source"], - capabilities_override={}, - enabled=False, - catalog=item.get("metadata") or {}, - ) - _apply_model_facts(db_model, item) - session.add(db_model) - else: - db_model.display_name = item.get("display_name") or db_model.display_name - _apply_model_facts(db_model, item) - db_model.catalog = item.get("metadata") or db_model.catalog - await session.commit() - conn = await _load_connection(session, connection_id) - await _default_unset_roles(session, conn, list(conn.models)) - if conn.search_space_id is not None: - await _clear_invalid_roles(session, conn.search_space_id) - await session.commit() - conn = await _load_connection(session, connection_id) - return [_model_read(model) for model in conn.models] - - -@router.post("/model-connections/{connection_id}/models", response_model=ModelRead) -async def add_manual_model( - connection_id: int, - data: ModelCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - - model_id = data.model_id.strip() - if not model_id: - raise HTTPException(status_code=400, detail="model_id is required") - if any(existing.model_id == model_id for existing in conn.models): - raise HTTPException( - status_code=400, detail="Model already exists on this connection" - ) - - capabilities = derive_capabilities(conn, model_id) - model = Model( - connection_id=conn.id, - model_id=model_id, - display_name=data.display_name or None, - source=ModelSource.MANUAL, - capabilities_override={}, - enabled=True, - catalog={}, - ) - _apply_model_facts(model, capabilities) - session.add(model) - await session.commit() - await session.refresh(model) - conn = await _load_connection(session, connection_id) - await 
_default_unset_roles(session, conn, list(conn.models)) - if conn.search_space_id is not None: - await _clear_invalid_roles(session, conn.search_space_id) - await session.commit() - await session.refresh(model) - return _model_read(model) - - -@router.patch( - "/model-connections/{connection_id}/models", response_model=list[ModelRead] -) -async def bulk_update_models( - connection_id: int, - data: ModelsBulkUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - conn = await _load_connection(session, connection_id) - await _assert_connection_access( - session, user, conn, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = conn.search_space_id - - model_ids = set(data.model_ids) - await session.execute( - update(Model) - .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) - .values(enabled=data.enabled) - ) - await session.commit() - session.expire_all() - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - session.expire_all() - - result = await session.execute( - select(Model) - .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) - .order_by(Model.id) - ) - return [_model_read(model) for model in result.scalars().all()] - - -@router.put("/models/{model_id}", response_model=ModelRead) -async def update_model( - model_id: int, - data: ModelUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value - ) - search_space_id = model.connection.search_space_id - update = data.model_dump(exclude_unset=True) - for key, 
value in update.items(): - setattr(model, key, value) - await session.commit() - await session.refresh(model) - if search_space_id is not None: - await _clear_invalid_roles(session, search_space_id) - await session.commit() - await session.refresh(model) - return _model_read(model) - - -@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse) -async def test_connection_model( - model_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id) - ) - model = result.scalars().first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - await _assert_connection_access( - session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value - ) - result = await test_model(model.connection, model) - await session.commit() - return VerifyConnectionResponse( - status=result.status, ok=result.ok, message=result.message - ) - - -@router.get( - "/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead -) -async def get_model_roles( - search_space_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to view model roles in this search space", - ) - search_space = await _clear_invalid_roles(session, search_space_id) - await session.commit() - await session.refresh(search_space) - return ModelRolesRead( - chat_model_id=search_space.chat_model_id, - vision_model_id=search_space.vision_model_id, - image_gen_model_id=search_space.image_gen_model_id, - ) - - -@router.put( - "/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead -) -async def update_model_roles( - search_space_id: int, - data: ModelRolesUpdate, - session: 
AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update model roles in this search space", - ) - search_space = await _get_search_space(session, search_space_id) - updates = data.model_dump(exclude_unset=True) - if "chat_model_id" in updates: - previous_chat_model_id = search_space.chat_model_id - next_chat_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["chat_model_id"], - capability="chat", - ) - search_space.chat_model_id = next_chat_model_id - if next_chat_model_id != previous_chat_model_id: - await session.execute( - update(NewChatThread) - .where(NewChatThread.search_space_id == search_space_id) - .values(pinned_llm_config_id=None) - ) - logger.info( - "Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)", - search_space_id, - previous_chat_model_id, - next_chat_model_id, - ) - if "vision_model_id" in updates: - search_space.vision_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["vision_model_id"], - capability="vision", - ) - if "image_gen_model_id" in updates: - search_space.image_gen_model_id = await _validate_role_model_id( - session, - search_space_id=search_space_id, - model_id=updates["image_gen_model_id"], - capability="image_gen", - ) - await session.commit() - await session.refresh(search_space) - return ModelRolesRead( - chat_model_id=search_space.chat_model_id, - vision_model_id=search_space.vision_model_id, - image_gen_model_id=search_space.image_gen_model_id, - ) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5bc2571e..0e4e557be 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1741,11 
+1741,12 @@ async def handle_new_chat( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - # Use the converged model-connections role for chat operations. - # Positive IDs load Model + Connection rows; negative IDs load - # virtual GLOBAL models; 0 means Auto. + # Use agent_llm_id from search space for chat operations + # Positive IDs load from NewLLMConfig database table + # Negative IDs load from YAML global configs + # Falls back to -1 (first global config) if not configured llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2227,7 +2228,7 @@ async def regenerate_response( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2392,7 +2393,7 @@ async def resume_chat( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.chat_model_id if search_space.chat_model_id is not None else 0 + search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 ) decisions = [d.model_dump() for d in request.decisions] diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py new file mode 100644 index 000000000..84d66bb13 --- /dev/null +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -0,0 +1,480 @@ +""" +API routes for NewLLMConfig CRUD operations. + +NewLLMConfig combines model settings with prompt configuration: +- LLM provider, model, API key, etc. 
+- Configurable system instructions +- Citation toggle +""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + NewLLMConfig, + Permission, + User, + get_async_session, +) +from app.prompts.default_system_instructions import get_default_system_instructions +from app.schemas import ( + DefaultSystemInstructionsResponse, + GlobalNewLLMConfigRead, + NewLLMConfigCreate, + NewLLMConfigRead, + NewLLMConfigUpdate, +) +from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. 
+ """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + +# ============================================================================= +# Global Configs Routes +# ============================================================================= + + +@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead]) +async def get_global_new_llm_configs( + user: User = Depends(current_active_user), +): + """ + Get all available global NewLLMConfig configurations. + These are pre-configured by the system administrator and available to all users. + API keys are not exposed through this endpoint. 
+ + Includes: + - Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing + - Global configs (negative IDs): Individual pre-configured LLM providers + """ + try: + global_configs = config.GLOBAL_LLM_CONFIGS + safe_configs = [] + + # Only include Auto mode if there are actual global configs to route to + # Auto mode requires at least one global config with valid API key + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "litellm_params": {}, + "system_instructions": "", + "use_default_system_instructions": True, + "citations_enabled": True, + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + "is_premium": False, + "anonymous_enabled": False, + "seo_enabled": False, + "seo_slug": None, + "seo_title": None, + "seo_description": None, + "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, + } + ) + + # Add individual global configs + for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. 
+ if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + + safe_config = { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "litellm_params": cfg.get("litellm_params", {}), + # New prompt configuration fields + "system_instructions": cfg.get("system_instructions", ""), + "use_default_system_instructions": cfg.get( + "use_default_system_instructions", True + ), + "citations_enabled": cfg.get("citations_enabled", True), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + "is_premium": cfg.get("billing_tier", "free") == "premium", + "anonymous_enabled": cfg.get("anonymous_enabled", False), + "seo_enabled": cfg.get("seo_enabled", False), + "seo_slug": cfg.get("seo_slug"), + "seo_title": cfg.get("seo_title"), + "seo_description": cfg.get("seo_description"), + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, + } + safe_configs.append(safe_config) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global NewLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch global configurations: {e!s}" + ) from e + + +# ============================================================================= +# CRUD Routes +# ============================================================================= + + 
+@router.post("/new-llm-configs", response_model=NewLLMConfigRead) +async def create_new_llm_config( + config_data: NewLLMConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new NewLLMConfig for a search space. + Requires LLM_CONFIGS_CREATE permission. + """ + try: + # Verify user has permission + await check_permission( + session, + user, + config_data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create LLM configurations in this search space", + ) + + # Validate the LLM configuration by making a test API call + is_valid, error_message = await validate_llm_config( + provider=config_data.provider.value, + model_name=config_data.model_name, + api_key=config_data.api_key, + api_base=config_data.api_base, + custom_provider=config_data.custom_provider, + litellm_params=config_data.litellm_params, + ) + + if not is_valid: + raise HTTPException( + status_code=400, + detail=f"Invalid LLM configuration: {error_message}", + ) + + # Create the config with user association + db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + + return _serialize_byok_config(db_config) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create configuration: {e!s}" + ) from e + + +@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead]) +async def list_new_llm_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get all NewLLMConfigs for a search space. + Requires LLM_CONFIGS_READ permission. 
+ """ + try: + # Verify user has permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) + + result = await session.execute( + select(NewLLMConfig) + .filter(NewLLMConfig.search_space_id == search_space_id) + .order_by(NewLLMConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list NewLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configurations: {e!s}" + ) from e + + +@router.get( + "/new-llm-configs/default-system-instructions", + response_model=DefaultSystemInstructionsResponse, +) +async def get_default_system_instructions_endpoint( + user: User = Depends(current_active_user), +): + """ + Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template. + Useful for pre-populating the UI when creating a new configuration. + """ + return DefaultSystemInstructionsResponse( + default_system_instructions=get_default_system_instructions() + ) + + +@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) +async def get_new_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific NewLLMConfig by ID. + Requires LLM_CONFIGS_READ permission. 
+ """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) + + return _serialize_byok_config(config) + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configuration: {e!s}" + ) from e + + +@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) +async def update_new_llm_config( + config_id: int, + update_data: NewLLMConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an existing NewLLMConfig. + Requires LLM_CONFIGS_UPDATE permission. 
+ """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update LLM configurations in this search space", + ) + + update_dict = update_data.model_dump(exclude_unset=True) + + # If updating LLM settings, validate them + if any( + key in update_dict + for key in [ + "provider", + "model_name", + "api_key", + "api_base", + "custom_provider", + "litellm_params", + ] + ): + # Build the validation config from existing + updates + validation_config = { + "provider": update_dict.get("provider", config.provider).value + if hasattr(update_dict.get("provider", config.provider), "value") + else update_dict.get("provider", config.provider.value), + "model_name": update_dict.get("model_name", config.model_name), + "api_key": update_dict.get("api_key", config.api_key), + "api_base": update_dict.get("api_base", config.api_base), + "custom_provider": update_dict.get( + "custom_provider", config.custom_provider + ), + "litellm_params": update_dict.get( + "litellm_params", config.litellm_params + ), + } + + is_valid, error_message = await validate_llm_config( + provider=validation_config["provider"], + model_name=validation_config["model_name"], + api_key=validation_config["api_key"], + api_base=validation_config["api_base"], + custom_provider=validation_config["custom_provider"], + litellm_params=validation_config["litellm_params"], + ) + + if not is_valid: + raise HTTPException( + status_code=400, + detail=f"Invalid LLM configuration: {error_message}", + ) + + # Apply updates + for key, value in update_dict.items(): + setattr(config, key, value) + + await session.commit() + await session.refresh(config) + + return 
_serialize_byok_config(config) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update configuration: {e!s}" + ) from e + + +@router.delete("/new-llm-configs/{config_id}", response_model=dict) +async def delete_new_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a NewLLMConfig. + Requires LLM_CONFIGS_DELETE permission. + """ + try: + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + config = result.scalars().first() + + if not config: + raise HTTPException(status_code=404, detail="Configuration not found") + + # Verify user has permission + await check_permission( + session, + user, + config.search_space_id, + Permission.LLM_CONFIGS_DELETE.value, + "You don't have permission to delete LLM configurations in this search space", + ) + + await session.delete(config) + await session.commit() + + return {"message": "Configuration deleted successfully", "id": config_id} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete NewLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete configuration: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index 516e976e6..53f4c2651 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -103,14 +103,8 @@ async def stream_public_podcast( if storage_key: from app.file_storage.factory import get_storage_backend - backend = get_storage_backend() - # Verify first so a missing object is a 404, not a mid-stream crash. 
- if not await backend.exists(storage_key): - raise HTTPException( - status_code=404, detail="Podcast audio is no longer available" - ) return StreamingResponse( - backend.open_stream(storage_key), + get_storage_backend().open_stream(storage_key), media_type="audio/mpeg", headers={"Accept-Ranges": "bytes"}, ) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 512b52ae4..dc26b4c02 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -745,23 +745,11 @@ async def index_connector_content( if not connector: raise HTTPException(status_code=404, detail="Connector not found") - # Ensure the connector actually belongs to the requested search space. - # Without this, the permission check below would authorize against the - # caller-supplied search_space_id (their own space) while the connector - # lives in another user's space, allowing cross-tenant indexing of a - # foreign connector (and use of its stored credentials). Returning 404 - # (rather than 403) on a mismatch also avoids disclosing the existence of - # connectors in other search spaces. - if connector.search_space_id != search_space_id: - raise HTTPException(status_code=404, detail="Connector not found") - - # Check if user has permission to update connectors (indexing is an update - # operation). Authorize against the connector's OWN search space — matching - # the read/update/delete handlers — not the client-supplied query param. 
+ # Check if user has permission to update connectors (indexing is an update operation) await check_permission( session, user, - connector.search_space_id, + search_space_id, Permission.CONNECTORS_UPDATE.value, "You don't have permission to index content in this search space", ) diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 592a9dd0e..898077b7a 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,20 +1,27 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.config import config from app.db import ( + ImageGenerationConfig, + NewChatThread, + NewLLMConfig, Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, User, + VisionLLMConfig, get_async_session, get_default_roles_config, ) from app.schemas import ( + LLMPreferencesRead, + LLMPreferencesUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, @@ -370,6 +377,357 @@ async def delete_search_space( ) from e +# ============================================================================= +# LLM Preferences Routes +# ============================================================================= + + +async def _get_llm_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + """ + Get an LLM config by ID as a dictionary. Returns database config for positive IDs, + global config for negative IDs, Auto mode config for ID 0, or None if ID is None. 
+ """ + if config_id is None: + return None + + # Auto mode (ID 0) - uses LiteLLM Router for load balancing + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "litellm_params": {}, + "system_instructions": "", + "use_default_system_instructions": True, + "citations_enabled": True, + "is_global": True, + "is_auto_mode": True, + } + + if config_id < 0: + # Global config - find from YAML + global_configs = config.GLOBAL_LLM_CONFIGS + for cfg in global_configs: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base"), + "litellm_params": cfg.get("litellm_params", {}), + "system_instructions": cfg.get("system_instructions", ""), + "use_default_system_instructions": cfg.get( + "use_default_system_instructions", True + ), + "citations_enabled": cfg.get("citations_enabled", True), + "is_global": True, + } + return None + else: + # Database config - convert to dict + result = await session.execute( + select(NewLLMConfig).filter(NewLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_key": db_config.api_key, + "api_base": db_config.api_base, + "litellm_params": db_config.litellm_params or {}, + "system_instructions": db_config.system_instructions or "", + "use_default_system_instructions": 
db_config.use_default_system_instructions, + "citations_enabled": db_config.citations_enabled, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +async def _get_image_gen_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + """ + Get an image generation config by ID as a dictionary. + Returns Auto mode for ID 0, global config for negative IDs, + DB ImageGenerationConfig for positive IDs, or None. + """ + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available image generation providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + } + + if config_id < 0: + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + } + return None + + # Positive ID: query ImageGenerationConfig table + result = await session.execute( + select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + 
"litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +async def _get_vision_llm_config_by_id( + session: AsyncSession, config_id: int | None +) -> dict | None: + if config_id is None: + return None + + if config_id == 0: + return { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available vision LLM providers", + "provider": "AUTO", + "model_name": "auto", + "is_global": True, + "is_auto_mode": True, + "billing_tier": "free", + } + + if config_id < 0: + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if cfg.get("id") == config_id: + return { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": cfg.get("billing_tier", "free"), + } + return None + + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if db_config: + return { + "id": db_config.id, + "name": db_config.name, + "description": db_config.description, + "provider": db_config.provider.value if db_config.provider else None, + "custom_provider": db_config.custom_provider, + "model_name": db_config.model_name, + "api_base": db_config.api_base, + "api_version": db_config.api_version, + "litellm_params": db_config.litellm_params or {}, + "created_at": db_config.created_at.isoformat() + if db_config.created_at + else None, + "search_space_id": db_config.search_space_id, + } + return None + + +@router.get( + "/search-spaces/{search_space_id}/llm-preferences", + 
response_model=LLMPreferencesRead, +) +async def get_llm_preferences( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get LLM preferences (role assignments) for a search space. + Requires LLM_CONFIGS_READ permission. + """ + try: + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM preferences", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + # Get full config objects for each role + agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) + vision_llm_config = await _get_vision_llm_config_by_id( + session, search_space.vision_llm_config_id + ) + + return LLMPreferencesRead( + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, + agent_llm=agent_llm, + image_generation_config=image_generation_config, + vision_llm_config=vision_llm_config, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get LLM preferences") + raise HTTPException( + status_code=500, detail=f"Failed to get LLM preferences: {e!s}" + ) from e + + +@router.put( + "/search-spaces/{search_space_id}/llm-preferences", + response_model=LLMPreferencesRead, +) +async def update_llm_preferences( + search_space_id: int, + preferences: LLMPreferencesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update LLM preferences (role assignments) for a search space. 
+ Requires LLM_CONFIGS_UPDATE permission. + """ + try: + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update LLM preferences", + ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + # Update preferences + update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id + for key, value in update_data.items(): + setattr(search_space, key, value) + + agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + + await session.commit() + await session.refresh(search_space) + + # Get full config objects for response + agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) + image_generation_config = await _get_image_gen_config_by_id( + session, search_space.image_generation_config_id + ) + vision_llm_config = await _get_vision_llm_config_by_id( + session, search_space.vision_llm_config_id + ) + + return LLMPreferencesRead( + agent_llm_id=search_space.agent_llm_id, + image_generation_config_id=search_space.image_generation_config_id, + vision_llm_config_id=search_space.vision_llm_config_id, + agent_llm=agent_llm, + image_generation_config=image_generation_config, + vision_llm_config=vision_llm_config, + ) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + 
logger.exception("Failed to update LLM preferences") + raise HTTPException( + status_code=500, detail=f"Failed to update LLM preferences: {e!s}" + ) from e + + @router.get("/searchspaces/{search_space_id}/snapshots") async def list_search_space_snapshots( search_space_id: int, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py new file mode 100644 index 000000000..e4f08f604 --- /dev/null +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -0,0 +1,304 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ( + Permission, + User, + VisionLLMConfig, + get_async_session, +) +from app.schemas import ( + GlobalVisionLLMConfigRead, + VisionLLMConfigCreate, + VisionLLMConfigRead, + VisionLLMConfigUpdate, +) +from app.services.vision_model_list_service import get_vision_model_list +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Vision Model Catalogue (from OpenRouter, filtered for image-input models) +# ============================================================================= + + +class VisionModelListItem(BaseModel): + value: str + label: str + provider: str + context_window: str | None = None + + +@router.get("/vision-models", response_model=list[VisionModelListItem]) +async def list_vision_models( + user: User = Depends(current_active_user), +): + """Return vision-capable models sourced from OpenRouter (filtered by image input).""" + try: + return await get_vision_model_list() + except Exception as e: + logger.exception("Failed to fetch vision model list") + raise HTTPException( + status_code=500, detail=f"Failed to fetch vision 
model list: {e!s}" + ) from e + + +# ============================================================================= +# Global Vision LLM Configs (from YAML) +# ============================================================================= + + +@router.get( + "/global-vision-llm-configs", + response_model=list[GlobalVisionLLMConfigRead], +) +async def get_global_vision_llm_configs( + user: User = Depends(current_active_user), +): + try: + global_configs = config.GLOBAL_VISION_LLM_CONFIGS + safe_configs = [] + + if global_configs and len(global_configs) > 0: + safe_configs.append( + { + "id": 0, + "name": "Auto (Fastest)", + "description": "Automatically routes across available vision LLM providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + # Auto mode treated as free until per-deployment billing-tier + # surfacing lands; see ``get_vision_llm`` for parity. + "billing_tier": "free", + "is_premium": False, + } + ) + + for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. 
+ "is_premium": billing_tier == "premium", + "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "input_cost_per_token": cfg.get("input_cost_per_token"), + "output_cost_per_token": cfg.get("output_cost_per_token"), + } + ) + + return safe_configs + except Exception as e: + logger.exception("Failed to fetch global vision LLM configs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +# ============================================================================= +# VisionLLMConfig CRUD +# ============================================================================= + + +@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead) +async def create_vision_llm_config( + config_data: VisionLLMConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + await check_permission( + session, + user, + config_data.search_space_id, + Permission.VISION_CONFIGS_CREATE.value, + "You don't have permission to create vision LLM configs in this search space", + ) + + db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id) + session.add(db_config) + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to create VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e + + +@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead]) +async def list_vision_llm_configs( + search_space_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + await check_permission( + session, + user, + search_space_id, + Permission.VISION_CONFIGS_READ.value, + "You don't have permission to view vision LLM configs in this search space", + ) + + result = 
await session.execute( + select(VisionLLMConfig) + .filter(VisionLLMConfig.search_space_id == search_space_id) + .order_by(VisionLLMConfig.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to list VisionLLMConfigs") + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e + + +@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) +async def get_vision_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_READ.value, + "You don't have permission to view vision LLM configs in this search space", + ) + return db_config + + except HTTPException: + raise + except Exception as e: + logger.exception("Failed to get VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e + + +@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) +async def update_vision_llm_config( + config_id: int, + update_data: VisionLLMConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_CREATE.value, + "You don't have 
permission to update vision LLM configs in this search space", + ) + + for key, value in update_data.model_dump(exclude_unset=True).items(): + setattr(db_config, key, value) + + await session.commit() + await session.refresh(db_config) + return db_config + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to update VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e + + +@router.delete("/vision-llm-configs/{config_id}", response_model=dict) +async def delete_vision_llm_config( + config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + try: + result = await session.execute( + select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) + ) + db_config = result.scalars().first() + if not db_config: + raise HTTPException(status_code=404, detail="Config not found") + + await check_permission( + session, + user, + db_config.search_space_id, + Permission.VISION_CONFIGS_DELETE.value, + "You don't have permission to delete vision LLM configs in this search space", + ) + + await session.delete(db_config) + await session.commit() + return { + "message": "Vision LLM config deleted successfully", + "id": config_id, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.exception("Failed to delete VisionLLMConfig") + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 7b508a132..212a6aa44 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -34,27 +34,16 @@ from .folders import ( ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest from .image_generation import ( + GlobalImageGenConfigRead, + 
ImageGenerationConfigCreate, + ImageGenerationConfigPublic, + ImageGenerationConfigRead, + ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate -from .model_connections import ( - ConnectionCreate, - ConnectionRead, - ConnectionUpdate, - ModelCreate, - ModelPreviewRead, - ModelProviderRead, - ModelRead, - ModelRolesRead, - ModelRolesUpdate, - ModelsBulkUpdate, - ModelSelection, - ModelTestPreview, - ModelUpdate, - VerifyConnectionResponse, -) from .new_chat import ( ChatMessage, NewChatMessageAppend, @@ -69,6 +58,16 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) +from .new_llm_config import ( + DefaultSystemInstructionsResponse, + GlobalNewLLMConfigRead, + LLMPreferencesRead, + LLMPreferencesUpdate, + NewLLMConfigCreate, + NewLLMConfigPublic, + NewLLMConfigRead, + NewLLMConfigUpdate, +) from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -127,6 +126,13 @@ from .video_presentations import ( VideoPresentationRead, VideoPresentationUpdate, ) +from .vision_llm import ( + GlobalVisionLLMConfigRead, + VisionLLMConfigCreate, + VisionLLMConfigPublic, + VisionLLMConfigRead, + VisionLLMConfigUpdate, +) __all__ = [ # Folder schemas @@ -138,15 +144,12 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", - # Model connection schemas - "ConnectionCreate", - "ConnectionRead", - "ConnectionUpdate", "CreateCreditCheckoutSessionRequest", "CreateCreditCheckoutSessionResponse", "CreditPurchaseHistoryResponse", "CreditPurchaseRead", "CreditStripeStatusResponse", + "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", "DocumentMove", @@ -169,10 +172,19 @@ __all__ = [ "FolderRead", "FolderReorder", "FolderUpdate", + "GlobalImageGenConfigRead", + "GlobalNewLLMConfigRead", + # Vision LLM Config schemas + "GlobalVisionLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", 
+ # Image Generation Config schemas + "ImageGenerationConfigCreate", + "ImageGenerationConfigPublic", + "ImageGenerationConfigRead", + "ImageGenerationConfigUpdate", # Image Generation schemas "ImageGenerationCreate", "ImageGenerationListRead", @@ -184,6 +196,9 @@ __all__ = [ "InviteInfoResponse", "InviteRead", "InviteUpdate", + # LLM Preferences schemas + "LLMPreferencesRead", + "LLMPreferencesUpdate", # Log schemas "LogBase", "LogCreate", @@ -202,16 +217,6 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", - "ModelCreate", - "ModelPreviewRead", - "ModelProviderRead", - "ModelRead", - "ModelRolesRead", - "ModelRolesUpdate", - "ModelSelection", - "ModelTestPreview", - "ModelUpdate", - "ModelsBulkUpdate", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -220,6 +225,11 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", + # NewLLMConfig schemas + "NewLLMConfigCreate", + "NewLLMConfigPublic", + "NewLLMConfigRead", + "NewLLMConfigUpdate", "PagePurchaseHistoryResponse", "PagePurchaseRead", "PaginatedResponse", @@ -257,10 +267,13 @@ __all__ = [ "UserRead", "UserSearchSpaceAccess", "UserUpdate", - "VerifyConnectionResponse", # Video Presentation schemas "VideoPresentationBase", "VideoPresentationCreate", "VideoPresentationRead", "VideoPresentationUpdate", + "VisionLLMConfigCreate", + "VisionLLMConfigPublic", + "VisionLLMConfigRead", + "VisionLLMConfigUpdate", ] diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index ebd0fa0ac..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -1,10 +1,109 @@ -"""Pydantic schemas for image generation requests/results.""" +""" +Pydantic schemas for Image Generation configs and generation requests. +ImageGenerationConfig: CRUD schemas for user-created image gen model configs. 
+ImageGeneration: Schemas for the actual image generation requests/results. +GlobalImageGenConfigRead: Schema for admin-configured YAML configs. +""" + +import uuid from datetime import datetime from typing import Any from pydantic import BaseModel, ConfigDict, Field +from app.db import ImageGenProvider + +# ============================================================================= +# ImageGenerationConfig CRUD Schemas +# ============================================================================= + + +class ImageGenerationConfigBase(BaseModel): + """Base schema with fields for ImageGenerationConfig.""" + + name: str = Field( + ..., max_length=100, description="User-friendly name for the config" + ) + description: str | None = Field( + None, max_length=500, description="Optional description" + ) + provider: ImageGenProvider = Field( + ..., + description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", + ) + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name" + ) + model_name: str = Field( + ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" + ) + api_key: str = Field(..., description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + api_version: str | None = Field( + None, + max_length=50, + description="Azure-specific API version (e.g., '2024-02-15-preview')", + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + + +class ImageGenerationConfigCreate(ImageGenerationConfigBase): + """Schema for creating a new ImageGenerationConfig.""" + + search_space_id: int = Field( + ..., description="Search space ID to associate the config with" + ) + + +class ImageGenerationConfigUpdate(BaseModel): + """Schema for updating an existing ImageGenerationConfig. 
All fields optional.""" + + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + provider: ImageGenProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = None + + +class ImageGenerationConfigRead(ImageGenerationConfigBase): + """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" + + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class ImageGenerationConfigPublic(BaseModel): + """Public schema that hides the API key (for list views).""" + + id: int + name: str + description: str | None = None + provider: ImageGenProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + # ============================================================================= # ImageGeneration (request/result) Schemas # ============================================================================= @@ -37,12 +136,12 @@ class ImageGenerationCreate(BaseModel): search_space_id: int = Field( ..., description="Search space ID to associate the generation with" ) - image_gen_model_id: int | None = Field( + image_generation_config_id: int | None = Field( None, description=( - "Image generation model ID. " - "0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. " - "If not provided, uses the search space's image_gen_model_id preference." + "Image generation config ID. 
" + "0 = Auto mode (router), negative = global YAML config, positive = DB config. " + "If not provided, uses the search space's image_generation_config_id preference." ), ) @@ -58,7 +157,7 @@ class ImageGenerationRead(BaseModel): size: str | None = None style: str | None = None response_format: str | None = None - image_gen_model_id: int | None = None + image_generation_config_id: int | None = None response_data: dict[str, Any] | None = None error_message: str | None = None search_space_id: int @@ -104,3 +203,58 @@ class ImageGenerationListRead(BaseModel): is_success=obj.response_data is not None, image_count=image_count, ) + + +# ============================================================================= +# Global Image Gen Config (from YAML) +# ============================================================================= + + +class GlobalImageGenConfigRead(BaseModel): + """ + Schema for reading global image generation configs from YAML. + Global configs have negative IDs. API key is hidden. + ID 0 is reserved for Auto mode (LiteLLM Router load balancing). + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. ``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + + id: int = Field( + ..., + description="Config ID: 0 for Auto mode, negative for global configs", + ) + name: str + description: str | None = None + provider: str + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + is_global: bool = True + is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. 
Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_micros: int | None = Field( + default=None, + description=( + "Optional override for the reservation amount (in micro-USD) used when " + "this image generation is premium. Falls back to " + "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." + ), + ) diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py deleted file mode 100644 index 0eec666c1..000000000 --- a/surfsense_backend/app/schemas/model_connections.py +++ /dev/null @@ -1,148 +0,0 @@ -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import ConnectionScope, ModelSource - - -class ModelRead(BaseModel): - id: int - connection_id: int - model_id: str - display_name: str | None = None - source: ModelSource | str - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - capabilities_override: dict[str, Any] = Field(default_factory=dict) - enabled: bool - billing_tier: str | None = None - catalog: dict[str, Any] = Field(default_factory=dict) - created_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - -class ConnectionRead(BaseModel): - id: int - provider: str - base_url: str | None = None - api_key: str | None = None - extra: dict[str, Any] = Field(default_factory=dict) - scope: ConnectionScope | str - search_space_id: int | None = None - user_id: uuid.UUID | None = None - enabled: bool - has_api_key: bool - models: 
list[ModelRead] = Field(default_factory=list) - created_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - -class ModelSelection(BaseModel): - model_id: str = Field(..., max_length=255) - display_name: str | None = Field(None, max_length=255) - source: ModelSource | str = ModelSource.DISCOVERED - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - enabled: bool = False - metadata: dict[str, Any] = Field(default_factory=dict) - - -class ModelPreviewRead(BaseModel): - model_id: str - display_name: str | None = None - source: ModelSource | str = ModelSource.DISCOVERED - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - enabled: bool = False - metadata: dict[str, Any] = Field(default_factory=dict) - - -class ConnectionCreate(BaseModel): - provider: str = Field(..., max_length=100) - base_url: str | None = Field(None, max_length=500) - api_key: str | None = None - extra: dict[str, Any] = Field(default_factory=dict) - scope: ConnectionScope = ConnectionScope.SEARCH_SPACE - search_space_id: int | None = None - enabled: bool = True - models: list[ModelSelection] = Field(default_factory=list) - - -class ModelTestPreview(ConnectionCreate): - model_id: str = Field(..., max_length=255) - - -class ConnectionUpdate(BaseModel): - provider: str | None = Field(None, max_length=100) - base_url: str | None = Field(None, max_length=500) - api_key: str | None = None - extra: dict[str, Any] | None = None - enabled: bool | None = None - - -class ModelCreate(BaseModel): - """Manually register a model id on a connection. - - For providers without a usable ``/models`` endpoint (Perplexity, MiniMax, - Azure deployments, etc.) 
or to pin a single model from a noisy provider. - """ - - model_id: str = Field(..., max_length=255) - display_name: str | None = Field(None, max_length=255) - - -class ModelUpdate(BaseModel): - display_name: str | None = Field(None, max_length=255) - enabled: bool | None = None - supports_chat: bool | None = None - max_input_tokens: int | None = None - supports_image_input: bool | None = None - supports_tools: bool | None = None - supports_image_generation: bool | None = None - capabilities_override: dict[str, Any] | None = None - - -class ModelsBulkUpdate(BaseModel): - model_ids: list[int] = Field(..., min_length=1, max_length=1000) - enabled: bool - - -class ModelProviderRead(BaseModel): - provider: str - transport: str - discovery: str - default_base_url: str | None = None - base_url_required: bool - auth_style: str - local_only: bool = False - - -class VerifyConnectionResponse(BaseModel): - status: str - ok: bool - message: str = "" - - -class ModelRolesRead(BaseModel): - chat_model_id: int | None = 0 - vision_model_id: int | None = 0 - image_gen_model_id: int | None = 0 - - -class ModelRolesUpdate(BaseModel): - chat_model_id: int | None = None - vision_model_id: int | None = None - image_gen_model_id: int | None = None diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py new file mode 100644 index 000000000..716aa0457 --- /dev/null +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -0,0 +1,256 @@ +""" +Pydantic schemas for the NewLLMConfig API. + +NewLLMConfig combines model settings with prompt configuration: +- LLM provider, model, API key, etc. 
+- Configurable system instructions +- Citation toggle +""" + +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import LiteLLMProvider + + +class NewLLMConfigBase(BaseModel): + """Base schema with common fields for NewLLMConfig.""" + + name: str = Field( + ..., max_length=100, description="User-friendly name for the configuration" + ) + description: str | None = Field( + None, max_length=500, description="Optional description" + ) + + # Model Configuration + provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name when provider is CUSTOM" + ) + model_name: str = Field( + ..., max_length=100, description="Model name without provider prefix" + ) + api_key: str = Field(..., description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + + # Prompt Configuration + system_instructions: str = Field( + default="", + description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.", + ) + use_default_system_instructions: bool = Field( + default=True, + description="Whether to use default instructions when system_instructions is empty", + ) + citations_enabled: bool = Field( + default=True, + description="Whether to include citation instructions in the system prompt", + ) + + +class NewLLMConfigCreate(NewLLMConfigBase): + """Schema for creating a new NewLLMConfig.""" + + search_space_id: int = Field( + ..., description="Search space ID to associate the config with" + ) + + +class NewLLMConfigUpdate(BaseModel): + """Schema for updating an existing NewLLMConfig. 
All fields are optional.""" + + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + + # Model Configuration + provider: LiteLLMProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str | None = None + use_default_system_instructions: bool | None = None + citations_enabled: bool | None = None + + +class NewLLMConfigRead(NewLLMConfigBase): + """Schema for reading a NewLLMConfig (includes id and timestamps).""" + + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) + + model_config = ConfigDict(from_attributes=True) + + +class NewLLMConfigPublic(BaseModel): + """ + Public schema for NewLLMConfig that hides the API key. + Used when returning configs in list views or to users who shouldn't see keys. 
+ """ + + id: int + name: str + description: str | None = None + + # Model Configuration (no api_key) + provider: LiteLLMProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str + use_default_system_instructions: bool + citations_enabled: bool + + created_at: datetime + search_space_id: int + user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) + + model_config = ConfigDict(from_attributes=True) + + +class DefaultSystemInstructionsResponse(BaseModel): + """Response schema for getting default system instructions.""" + + default_system_instructions: str = Field( + ..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template" + ) + + +class GlobalNewLLMConfigRead(BaseModel): + """ + Schema for reading global LLM configs from YAML. + Global configs have negative IDs and no search_space_id. + API key is hidden for security. + + ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing. 
+ """ + + id: int = Field( + ..., + description="Config ID: 0 for Auto mode, negative for global configs", + ) + name: str + description: str | None = None + + # Model Configuration (no api_key) + provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode + custom_provider: str | None = None + model_name: str + api_base: str | None = None + litellm_params: dict[str, Any] | None = None + + # Prompt Configuration + system_instructions: str = "" + use_default_system_instructions: bool = True + citations_enabled: bool = True + + is_global: bool = True # Always true for global configs + is_auto_mode: bool = False # True only for Auto mode (ID 0) + + billing_tier: str = "free" + is_premium: bool = False + anonymous_enabled: bool = False + seo_enabled: bool = False + seo_slug: str | None = None + seo_title: str | None = None + seo_description: str | None = None + quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." 
+ ), + ) + + +# ============================================================================= +# LLM Preferences Schemas (for role assignments) +# ============================================================================= + + +class LLMPreferencesRead(BaseModel): + """Schema for reading LLM preferences (role assignments) for a search space.""" + + agent_llm_id: int | None = Field( + None, description="ID of the LLM config to use for agent/chat tasks" + ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) + vision_llm_config_id: int | None = Field( + None, + description="ID of the vision LLM config to use for vision/screenshot analysis", + ) + agent_llm: dict[str, Any] | None = Field( + None, description="Full config for agent LLM" + ) + image_generation_config: dict[str, Any] | None = Field( + None, description="Full config for image generation" + ) + vision_llm_config: dict[str, Any] | None = Field( + None, description="Full config for vision LLM" + ) + + model_config = ConfigDict(from_attributes=True) + + +class LLMPreferencesUpdate(BaseModel): + """Schema for updating LLM preferences.""" + + agent_llm_id: int | None = Field( + None, description="ID of the LLM config to use for agent/chat tasks" + ) + image_generation_config_id: int | None = Field( + None, description="ID of the image generation config to use" + ) + vision_llm_config_id: int | None = Field( + None, + description="ID of the vision LLM config to use for vision/screenshot analysis", + ) diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py new file mode 100644 index 000000000..d0eeaf5c6 --- /dev/null +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -0,0 +1,116 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import VisionProvider + + +class VisionLLMConfigBase(BaseModel): + name: str = 
Field(..., max_length=100) + description: str | None = Field(None, max_length=500) + provider: VisionProvider = Field(...) + custom_provider: str | None = Field(None, max_length=100) + model_name: str = Field(..., max_length=100) + api_key: str = Field(...) + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = Field(default=None) + + +class VisionLLMConfigCreate(VisionLLMConfigBase): + search_space_id: int = Field(...) + + +class VisionLLMConfigUpdate(BaseModel): + name: str | None = Field(None, max_length=100) + description: str | None = Field(None, max_length=500) + provider: VisionProvider | None = None + custom_provider: str | None = Field(None, max_length=100) + model_name: str | None = Field(None, max_length=100) + api_key: str | None = None + api_base: str | None = Field(None, max_length=500) + api_version: str | None = Field(None, max_length=50) + litellm_params: dict[str, Any] | None = None + + +class VisionLLMConfigRead(VisionLLMConfigBase): + id: int + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class VisionLLMConfigPublic(BaseModel): + id: int + name: str + description: str | None = None + provider: VisionProvider + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + created_at: datetime + search_space_id: int + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) + + +class GlobalVisionLLMConfigRead(BaseModel): + """Schema for reading global vision LLM configs from YAML. + + The ``billing_tier`` field allows the frontend to show a Premium/Free + badge and (more importantly) tells the backend whether to debit the + user's premium credit pool when this config is used. 
``"free"`` is + the default for backward compatibility — admins must explicitly opt + a global config into ``"premium"``. + """ + + id: int = Field(...) + name: str + description: str | None = None + provider: str + custom_provider: str | None = None + model_name: str + api_base: str | None = None + api_version: str | None = None + litellm_params: dict[str, Any] | None = None + is_global: bool = True + is_auto_mode: bool = False + billing_tier: str = Field( + default="free", + description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", + ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) + quota_reserve_tokens: int | None = Field( + default=None, + description=( + "Optional override for the per-call reservation in *tokens* — " + "converted to micro-USD via the model's input/output prices at " + "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." + ), + ) + input_cost_per_token: float | None = Field( + default=None, + description=( + "Optional input price in USD/token. Used by pricing_registration to " + "register custom Azure / OpenRouter aliases with LiteLLM at startup." + ), + ) + output_cost_per_token: float | None = Field( + default=None, + description="Optional output price in USD/token. 
Pair with input_cost_per_token.", + ) diff --git a/surfsense_backend/app/services/ai_file_sort_service.py b/surfsense_backend/app/services/ai_file_sort_service.py index 1bf4d325e..2f04131a6 100644 --- a/surfsense_backend/app/services/ai_file_sort_service.py +++ b/surfsense_backend/app/services/ai_file_sort_service.py @@ -156,7 +156,7 @@ async def _resolve_document_text( stmt = ( select(Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) .limit(_MAX_CHUNKS_FOR_CONTEXT) ) result = await session.execute(stmt) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index f98933a65..c9fd8c315 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -1,13 +1,13 @@ -"""Resolve and persist Auto model pins per chat thread. +"""Resolve and persist Auto (Fastest) model pins per chat thread. -Auto is represented by ``chat_model_id == 0``. For chat threads we -resolve that virtual mode to one concrete global model exactly once and +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so subsequent turns are stable. Single-writer invariant: this module is the only writer of ``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in -``model_connections_routes`` when a search space's ``chat_model_id`` changes). +``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). Therefore a non-NULL value unambiguously means "this thread has an Auto-resolved pin"; no separate source/policy column is needed. 
""" @@ -21,35 +21,26 @@ import time from dataclasses import dataclass from uuid import UUID -import redis from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.config import config -from app.db import Connection, Model, NewChatThread -from app.services.model_capabilities import has_capability +from app.db import NewChatThread from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) -AUTO_MODE_ID = 0 -# Stable internal hash namespace for deterministic per-thread selection. -# Do not rename: changing this rebalances Auto's model choice for new pins. -AUTO_PIN_HASH_NAMESPACE = "auto_fastest" +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 -_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:" -_REDIS_TIMEOUT_SECONDS = 0.2 # In-memory runtime cooldown map for configs that recently hard-failed at # provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps # the same unhealthy config from being reselected immediately during repair. _runtime_cooldown_until: dict[int, float] = {} _runtime_cooldown_lock = threading.Lock() -_runtime_cooldown_redis: redis.Redis | None = None -_runtime_cooldown_redis_lock = threading.Lock() # Short-TTL "recently healthy" cache for configs that just passed a runtime # preflight ping. 
Lets back-to-back turns on the same model skip the probe @@ -70,15 +61,11 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and (cfg.get("provider") or cfg.get("litellm_provider")) + and cfg.get("provider") and cfg.get("api_key") ) -def _has_capability(model: dict | Model, capability: str) -> bool: - return has_capability(model, capability) - - def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: now = time.time() if now_ts is None else now_ts stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] @@ -92,81 +79,6 @@ def _is_runtime_cooled_down(config_id: int) -> bool: return config_id in _runtime_cooldown_until -def _runtime_cooldown_redis_key(config_id: int) -> str: - return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}" - - -def _get_runtime_cooldown_redis() -> redis.Redis: - global _runtime_cooldown_redis - if _runtime_cooldown_redis is None: - with _runtime_cooldown_redis_lock: - if _runtime_cooldown_redis is None: - _runtime_cooldown_redis = redis.from_url( - config.REDIS_APP_URL, - decode_responses=True, - socket_connect_timeout=_REDIS_TIMEOUT_SECONDS, - socket_timeout=_REDIS_TIMEOUT_SECONDS, - ) - return _runtime_cooldown_redis - - -def _mark_shared_runtime_cooldown( - config_id: int, - *, - reason: str, - cooldown_seconds: int, -) -> None: - try: - _get_runtime_cooldown_redis().set( - _runtime_cooldown_redis_key(config_id), - reason, - ex=int(cooldown_seconds), - ) - except Exception: - logger.warning( - "auto_pin_runtime_cooldown_redis_write_failed config_id=%s", - config_id, - exc_info=True, - ) - - -def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]: - unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids)) - if not unique_ids: - return set() - try: - values = _get_runtime_cooldown_redis().mget( - [_runtime_cooldown_redis_key(cid) for cid in unique_ids] - ) - except Exception: - logger.warning( - 
"auto_pin_runtime_cooldown_redis_read_failed count=%s", - len(unique_ids), - exc_info=True, - ) - return set() - return { - cid for cid, value in zip(unique_ids, values, strict=False) if value is not None - } - - -def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None: - try: - client = _get_runtime_cooldown_redis() - if config_id is not None: - client.delete(_runtime_cooldown_redis_key(config_id)) - return - keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*")) - if keys: - client.delete(*keys) - except Exception: - logger.warning( - "auto_pin_runtime_cooldown_redis_clear_failed config_id=%s", - config_id, - exc_info=True, - ) - - def mark_runtime_cooldown( config_id: int, *, @@ -185,11 +97,6 @@ def mark_runtime_cooldown( with _runtime_cooldown_lock: _runtime_cooldown_until[int(config_id)] = until _prune_runtime_cooldowns() - _mark_shared_runtime_cooldown( - int(config_id), - reason=reason, - cooldown_seconds=int(cooldown_seconds), - ) # A cooled cfg can never be "recently healthy"; drop any stale credit so # the next turn that resolves to it (after cooldown) re-runs preflight. 
clear_healthy(int(config_id)) @@ -206,9 +113,8 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None: with _runtime_cooldown_lock: if config_id is None: _runtime_cooldown_until.clear() - else: - _runtime_cooldown_until.pop(int(config_id), None) - _clear_shared_runtime_cooldown(config_id) + return + _runtime_cooldown_until.pop(int(config_id), None) def _prune_healthy(now_ts: float | None = None) -> None: @@ -280,20 +186,15 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) -def _global_candidates( - *, - capability: str = "chat", - requires_image_input: bool = False, - shared_cooled_down_ids: set[int] | None = None, -) -> list[dict]: - """Return Auto-eligible global virtual models. +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: + """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers @@ -304,167 +205,30 @@ def _global_candidates( filters out configs whose ``supports_image_input`` resolves to False so a text-only deployment can't be pinned for an image request. 
""" - connection_by_id = { - int(conn.get("id")): conn - for conn in config.GLOBAL_CONNECTIONS - if conn.get("id") is not None - } - config_by_model_name = { - cfg.get("model_name"): cfg + candidates = [ + cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) - } - candidates: list[dict] = [] - shared_cooled_down_ids = shared_cooled_down_ids or set() - for model in config.GLOBAL_MODELS: - model_id = int(model.get("id", 0)) - if ( - model_id >= 0 - or _is_runtime_cooled_down(model_id) - or model_id in shared_cooled_down_ids - ): - continue - if not _has_capability(model, capability): - continue - cfg = config_by_model_name.get(model.get("model_id")) or {} - if cfg.get("health_gated"): - continue - if requires_image_input and not _has_capability(model, "vision"): - continue - if requires_image_input and cfg and not _cfg_supports_image_input(cfg): - continue - connection = connection_by_id.get(int(model.get("connection_id", 0))) - if not connection: - continue - catalog = model.get("catalog") or {} - candidates.append( - { - "id": model_id, - "model_id": model.get("model_id"), - "source": "global", - "connection": connection, - "supports_chat": model.get("supports_chat"), - "supports_image_input": model.get("supports_image_input"), - "supports_tools": model.get("supports_tools"), - "supports_image_generation": model.get("supports_image_generation"), - "capabilities_override": model.get("capabilities_override") or {}, - "billing_tier": model.get("billing_tier", "free"), - "provider": connection.get("provider"), - "model_name": model.get("model_id"), - "auto_pin_tier": catalog.get("auto_pin_tier") - or cfg.get("auto_pin_tier") - or "A", - "quality_score": catalog.get("quality_score") - or cfg.get("quality_score") - or cfg.get("quality_score_static") - or 50, - } - ) - return sorted(candidates, key=lambda c: int(c.get("id", 0))) - - -async def _db_candidates( - session: AsyncSession, - *, - search_space_id: int, - user_id: str | UUID | None, - 
capability: str, - requires_image_input: bool = False, -) -> list[dict]: - parsed_user_id = _to_uuid(user_id) - stmt = ( - select(Model) - .options(selectinload(Model.connection)) - .join(Connection, Model.connection_id == Connection.id) - .where(Model.enabled.is_(True), Connection.enabled.is_(True)) - ) - result = await session.execute(stmt) - models = result.scalars().all() - shared_cooled_down_ids = _shared_runtime_cooled_down_ids( - [int(model.id) for model in models] - ) - candidates: list[dict] = [] - for model in models: - conn = model.connection - if not conn: - continue - if conn.search_space_id is not None and conn.search_space_id != search_space_id: - continue - if ( - conn.user_id is not None - and parsed_user_id is not None - and conn.user_id != parsed_user_id - ): - continue - if conn.user_id is not None and parsed_user_id is None: - continue - if not _has_capability(model, capability): - continue - if requires_image_input and not _has_capability(model, "vision"): - continue - model_id = int(model.id) - if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids: - continue - catalog = model.catalog or {} - candidates.append( - { - "id": model_id, - "model_id": model.model_id, - "source": "db", - "connection": conn, - "supports_chat": model.supports_chat, - "supports_image_input": model.supports_image_input, - "supports_tools": model.supports_tools, - "supports_image_generation": model.supports_image_generation, - "capabilities_override": model.capabilities_override or {}, - "billing_tier": "byok", - "provider": conn.provider, - "model_name": model.model_id, - "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", - "quality_score": catalog.get("quality_score") or 75, - } - ) - return sorted(candidates, key=lambda c: int(c.get("id", 0))) - - -async def auto_model_candidates( - session: AsyncSession, - *, - search_space_id: int, - user_id: str | UUID | None, - capability: str, - requires_image_input: bool = False, - exclude_model_ids: 
set[int] | None = None, -) -> list[dict]: - excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} - global_ids = [ - int(model.get("id", 0)) - for model in config.GLOBAL_MODELS - if int(model.get("id", 0)) < 0 + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) ] - shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids) - db_candidates = await _db_candidates( - session, - search_space_id=search_space_id, - user_id=user_id, - capability=capability, - requires_image_input=requires_image_input, - ) - candidates = [ - *_global_candidates( - capability=capability, - requires_image_input=requires_image_input, - shared_cooled_down_ids=shared_global_cooled_down_ids, - ), - *db_candidates, - ] - return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() +def _is_preferred_premium_auto_config(cfg: dict) -> bool: + """Return True for the operator-preferred premium Auto model.""" + return ( + _tier_of(cfg) == "premium" + and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + and str(cfg.get("model_name", "")).lower() == "gpt-5.4" + ) + + def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: """Pick a config with quality-first ranking + deterministic spread. 
@@ -482,16 +246,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: pool = tier_a if tier_a else eligible pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) top_k = pool[:_QUALITY_TOP_K] - digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest() + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() idx = int.from_bytes(digest[:8], "big") % len(top_k) return top_k[idx], len(top_k) -def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict: - selected, _ = _select_pin(candidates, seed_id) - return selected - - def _to_uuid(user_id: str | UUID | None) -> UUID | None: if user_id is None: return None @@ -524,7 +283,7 @@ async def resolve_or_get_pinned_llm_config_id( exclude_config_ids: set[int] | None = None, requires_image_input: bool = False, ) -> AutoPinResolution: - """Resolve Auto to one concrete config id and persist the pin. + """Resolve Auto (Fastest) to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. @@ -556,7 +315,7 @@ async def resolve_or_get_pinned_llm_config_id( ) # Explicit model selected: clear any stale pin. 
- if selected_llm_config_id != AUTO_MODE_ID: + if selected_llm_config_id != AUTO_FASTEST_ID: if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None await session.commit() @@ -567,21 +326,20 @@ async def resolve_or_get_pinned_llm_config_id( ) excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=user_id, - capability="chat", - requires_image_input=requires_image_input, - exclude_model_ids=excluded_ids, - ) + candidates = [ + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids + ] if not candidates: if requires_image_input: # Distinguish the "no vision-capable cfg" case from generic # "no usable cfg" so the streaming task can map this to the # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. - raise ValueError("No vision-capable LLM models are available for Auto mode") - raise ValueError("No usable LLM models are available for Auto mode") + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) + raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent @@ -621,13 +379,24 @@ async def resolve_or_get_pinned_llm_config_id( # log that explicitly so operators can correlate the re-pin with # the user's image attachment instead of suspecting a cooldown. 
if requires_image_input: - logger.info( - "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " - "previous_config_id=%s", - thread_id, - search_space_id, - pinned_id, - ) + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, @@ -638,10 +407,12 @@ async def resolve_or_get_pinned_llm_config_id( premium_eligible = ( False if force_repin_free else await _is_premium_eligible(session, user_id) ) - byok_candidates = [c for c in candidates if _tier_of(c) == "byok"] if premium_eligible: premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] - eligible = premium_candidates or byok_candidates + preferred_premium = [ + c for c in premium_candidates if _is_preferred_premium_auto_config(c) + ] + eligible = preferred_premium or premium_candidates else: eligible = [c for c in candidates if _tier_of(c) != "premium"] diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 15a3c3e55..919c49a21 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -445,15 +445,15 @@ async def _resolve_agent_billing_for_search_space( thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space - chat model. + agent LLM. Used by Celery tasks (podcast generation, video presentation) to bill the - search-space owner's premium credit pool when the chat model is premium. 
+ search-space owner's premium credit pool when the agent LLM is premium. - Resolution rules mirror the chat model role resolver: + Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: - - Search space not found / no ``chat_model_id``: raise ``ValueError``. - - **Auto mode** (``id == AUTO_MODE_ID == 0``): + - Search space not found / no ``agent_llm_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present @@ -469,8 +469,9 @@ async def _resolve_agent_billing_for_search_space( (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - - **Positive id** (user BYOK ``Model``): always free; ``base_model`` from - the model catalog override or the upstream ``model_id``. + - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches + ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); + ``base_model`` from ``litellm_params`` or ``model_name``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to @@ -479,9 +480,8 @@ async def _resolve_agent_billing_for_search_space( ``billable_calls.py``'s module load path. 
""" from sqlalchemy import select - from sqlalchemy.orm import selectinload - from app.db import Model, SearchSpace + from app.db import NewLLMConfig, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space( if search_space is None: raise ValueError(f"Search space {search_space_id} not found") - chat_model_id = search_space.chat_model_id - if chat_model_id is None: + agent_llm_id = search_space.agent_llm_id + if agent_llm_id is None: raise ValueError( - f"Search space {search_space_id} has no chat_model_id configured" + f"Search space {search_space_id} has no agent_llm_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( - AUTO_MODE_ID, + AUTO_FASTEST_ID, resolve_or_get_pinned_llm_config_id, ) - if chat_model_id == AUTO_MODE_ID: + if agent_llm_id == AUTO_FASTEST_ID: if thread_id is None: return owner_user_id, "free", "auto" try: @@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space( thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), - selected_llm_config_id=AUTO_MODE_ID, + selected_llm_config_id=AUTO_FASTEST_ID, ) except ValueError: logger.warning( @@ -523,35 +523,28 @@ async def _resolve_agent_billing_for_search_space( exc_info=True, ) return owner_user_id, "free", "auto" - chat_model_id = resolution.resolved_llm_config_id + agent_llm_id = resolution.resolved_llm_config_id - if chat_model_id < 0: + if agent_llm_id < 0: from app.services.llm_service import get_global_llm_config - cfg = get_global_llm_config(chat_model_id) or {} + cfg = get_global_llm_config(agent_llm_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model - model_result = await session.execute( 
- select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == chat_model_id, Model.enabled.is_(True)) - ) - model = model_result.scalars().first() - base_model = "" - if ( - model is not None - and model.connection is not None - and model.connection.enabled - and ( - model.connection.search_space_id in (None, search_space_id) - and model.connection.user_id in (None, owner_user_id) + nlc_result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == agent_llm_id, + NewLLMConfig.search_space_id == search_space_id, ) - ): - catalog = model.catalog or {} - base_model = catalog.get("base_model") or model.model_id or "" + ) + nlc = nlc_result.scalars().first() + base_model = "" + if nlc is not None: + litellm_params = nlc.litellm_params or {} + base_model = litellm_params.get("base_model") or nlc.model_name or "" return owner_user_id, "free", base_model diff --git a/surfsense_backend/app/services/export_service.py b/surfsense_backend/app/services/export_service.py index 9e6869fe1..97f952223 100644 --- a/surfsense_backend/app/services/export_service.py +++ b/surfsense_backend/app/services/export_service.py @@ -62,7 +62,7 @@ async def _get_document_markdown( chunk_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document.id) - .order_by(Chunk.position, Chunk.id) + .order_by(Chunk.id) ) chunks = chunk_result.scalars().all() if chunks: diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py deleted file mode 100644 index 1bcc99215..000000000 --- a/surfsense_backend/app/services/global_model_catalog.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Materialize server-owned GLOBAL YAML configs as virtual connections/models.""" - -from __future__ import annotations - -from typing import Any - -from app.services.model_resolver import native_connection_from_config - - -def _base_model(config: dict[str, Any]) -> str | None: - litellm_params = 
config.get("litellm_params") or {} - if isinstance(litellm_params, dict): - return litellm_params.get("base_model") - return None - - -def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: - # Deliberately includes api_key because two operator-owned credentials for - # the same provider/base can have different quota/rate limits upstream. - return ( - conn.get("provider"), - conn.get("base_url"), - conn.get("api_key"), - _freeze(conn.get("extra") or {}), - ) - - -def _freeze(value: Any) -> Any: - if isinstance(value, dict): - return tuple(sorted((key, _freeze(val)) for key, val in value.items())) - if isinstance(value, list): - return tuple(_freeze(item) for item in value) - return value - - -def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: - return { - "billing_tier": config.get("billing_tier", "free"), - "quota_reserve_tokens": config.get("quota_reserve_tokens"), - "rpm": config.get("rpm"), - "tpm": config.get("tpm"), - "anonymous_enabled": config.get("anonymous_enabled", False), - "seo_enabled": config.get("seo_enabled", False), - "seo_slug": config.get("seo_slug"), - "input_cost_per_token": (config.get("litellm_params") or {}).get( - "input_cost_per_token" - ) - if isinstance(config.get("litellm_params"), dict) - else None, - "output_cost_per_token": (config.get("litellm_params") or {}).get( - "output_cost_per_token" - ) - if isinstance(config.get("litellm_params"), dict) - else None, - "is_planner": config.get("is_planner", False), - "base_model": _base_model(config), - "router_pool_eligible": config.get("router_pool_eligible", True), - } - - -def materialize_global_model_catalog( - *, - chat_configs: list[dict[str, Any]], - image_configs: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - connections: list[dict[str, Any]] = [] - models: list[dict[str, Any]] = [] - connection_id_by_key: dict[tuple[Any, ...], int] = {} - next_connection_id = -1 - - def add_config(config: dict[str, Any], role: str) -> None: - 
nonlocal next_connection_id - if not config.get("id") or not config.get("model_name"): - return - conn = native_connection_from_config(config) - conn["scope"] = "GLOBAL" - conn["enabled"] = True - key = _connection_key(conn) - connection_id = connection_id_by_key.get(key) - if connection_id is None: - connection_id = next_connection_id - next_connection_id -= 1 - connection_id_by_key[key] = connection_id - connections.append( - { - "id": connection_id, - **conn, - } - ) - - model_id = int(config["id"]) - models.append( - { - "id": model_id, - "connection_id": connection_id, - "model_id": config["model_name"], - "display_name": config.get("name") or config["model_name"], - "source": "MANUAL", - "supports_chat": role == "chat", - "max_input_tokens": config.get("max_input_tokens"), - "supports_image_input": ( - role == "chat" and bool(config.get("supports_image_input")) - ), - "supports_tools": bool(config.get("supports_tools", False)), - "supports_image_generation": role == "image_gen", - "capabilities_override": {}, - "enabled": True, - "billing_tier": config.get("billing_tier", "free"), - "catalog": _catalog_metadata(config), - "role": role, - } - ) - - for cfg in chat_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "chat") - for cfg in image_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "image_gen") - - # Each virtual connection is server-only. Callers that serialize these - # must strip api_key before returning data to clients. 
- return connections, models - - -__all__ = ["materialize_global_model_catalog"] diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index 241b3bc53..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,13 +20,28 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.provider_api_base import resolve_api_base logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing IMAGE_GEN_AUTO_MODE_ID = 0 +# Provider mapping for LiteLLM model string construction. +# Only includes providers that support image generation. +# See: https://docs.litellm.ai/docs/image_generation#supported-providers +IMAGE_GEN_PROVIDER_MAP = { + "OPENAI": "openai", + "AZURE_OPENAI": "azure", + "GOOGLE": "gemini", # Google AI Studio + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", # AWS Bedrock + "RECRAFT": "recraft", + "OPENROUTER": "openrouter", + "XINFERENCE": "xinference", + "NSCALE": "nscale", +} + class ImageGenRouterService: """ @@ -138,11 +153,38 @@ class ImageGenRouterService: if not config.get("model_name") or not config.get("api_key"): return None - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], + # Build model string + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + else: + provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Resolve ``api_base`` so deployments don't silently 
inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), ) - litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} + if api_base: + litellm_params["api_base"] = api_base + + # Add api_version (required for Azure) + if config.get("api_version"): + litellm_params["api_version"] = config["api_version"] + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) # All configs use same alias "auto" for unified routing deployment: dict[str, Any] = { diff --git a/surfsense_backend/app/services/llm_error_adapter.py b/surfsense_backend/app/services/llm_error_adapter.py deleted file mode 100644 index 8d451ee96..000000000 --- a/surfsense_backend/app/services/llm_error_adapter.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Normalize provider/LLM exceptions into low-cardinality product categories.""" - -from __future__ import annotations - -import json -from dataclasses import dataclass -from enum import StrEnum -from typing import Any - - -class LLMErrorCategory(StrEnum): - RATE_LIMITED = "rate_limited" - TIMEOUT = "timeout" - PROVIDER_UNAVAILABLE = "provider_unavailable" - BAD_GATEWAY = "bad_gateway" - CONNECTION_FAILED = "connection_failed" - AUTH_FAILED = "auth_failed" - PERMISSION_DENIED = "permission_denied" - MODEL_NOT_FOUND = "model_not_found" - BAD_REQUEST = "bad_request" - CONTEXT_LIMIT = "context_limit" - RESPONSE_INVALID = "response_invalid" - SERVER_ERROR = "server_error" - UNKNOWN = "unknown" - - -@dataclass(frozen=True) -class LLMErrorAdaptation: - category: LLMErrorCategory - retryable: bool - user_message: str - provider_status_code: int | None = None - provider_error_type: str | None = None - - -_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = { - 
LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.", - LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.", - LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.", - LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.", - LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.", - LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.", - LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.", - LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.", - LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.", - LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.", - LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.", - LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.", - LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.", -} - -_RETRYABLE_CATEGORIES = { - LLMErrorCategory.RATE_LIMITED, - LLMErrorCategory.TIMEOUT, - LLMErrorCategory.PROVIDER_UNAVAILABLE, - LLMErrorCategory.BAD_GATEWAY, - LLMErrorCategory.CONNECTION_FAILED, - LLMErrorCategory.SERVER_ERROR, -} - -_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] 
= ( - ( - LLMErrorCategory.RATE_LIMITED, - ("RateLimitError", "TooManyRequests", "TooManyRequestsError"), - ), - (LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")), - ( - LLMErrorCategory.PROVIDER_UNAVAILABLE, - ("ServiceUnavailableError", "ServiceUnavailable"), - ), - ( - LLMErrorCategory.BAD_GATEWAY, - ("BadGatewayError", "GatewayTimeoutError"), - ), - ( - LLMErrorCategory.CONNECTION_FAILED, - ("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"), - ), - ( - LLMErrorCategory.AUTH_FAILED, - ("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"), - ), - (LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")), - (LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")), - ( - LLMErrorCategory.CONTEXT_LIMIT, - ("ContextWindowExceeded", "ContextOverflow", "ContextLimit"), - ), - ( - LLMErrorCategory.RESPONSE_INVALID, - ("APIResponseValidationError", "ResponseValidationError"), - ), - ( - LLMErrorCategory.BAD_REQUEST, - ("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"), - ), - (LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)), -) - - -def _parse_error_payload(message: str) -> dict[str, Any] | None: - candidates = [message] - first_brace_idx = message.find("{") - if first_brace_idx >= 0: - candidates.append(message[first_brace_idx:]) - - for candidate in candidates: - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - return parsed - except Exception: - continue - return None - - -def _class_names(exc: BaseException) -> tuple[str, ...]: - return tuple(cls.__name__ for cls in type(exc).__mro__) - - -def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None: - names = _class_names(exc) - for category, hints in _CLASS_NAME_MAP: - if any(any(hint in name for hint in hints) for name in names): - return category - return None - - -def _extract_provider_status_code(parsed: dict[str, Any] | None) 
-> int | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("code"), parsed.get("status")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.extend([nested.get("code"), nested.get("status")]) - for value in candidates: - try: - if value is None: - continue - return int(value) - except Exception: - continue - return None - - -def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None: - if not isinstance(parsed, dict): - return None - candidates: list[Any] = [parsed.get("type")] - nested = parsed.get("error") - if isinstance(nested, dict): - candidates.append(nested.get("type")) - for value in candidates: - if isinstance(value, str) and value: - return value - return None - - -def _category_from_provider_payload( - status_code: int | None, - provider_error_type: str | None, -) -> LLMErrorCategory | None: - if status_code == 429: - return LLMErrorCategory.RATE_LIMITED - if status_code == 401: - return LLMErrorCategory.AUTH_FAILED - if status_code == 403: - return LLMErrorCategory.PERMISSION_DENIED - if status_code == 404: - return LLMErrorCategory.MODEL_NOT_FOUND - if status_code in (400, 422): - return LLMErrorCategory.BAD_REQUEST - if status_code in (502, 504): - return LLMErrorCategory.BAD_GATEWAY - if status_code == 503: - return LLMErrorCategory.PROVIDER_UNAVAILABLE - if status_code is not None and status_code >= 500: - return LLMErrorCategory.SERVER_ERROR - - normalized_type = (provider_error_type or "").lower() - if normalized_type == "rate_limit_error": - return LLMErrorCategory.RATE_LIMITED - if normalized_type in { - "authentication_error", - "invalid_api_key", - "invalid_api_key_error", - }: - return LLMErrorCategory.AUTH_FAILED - if normalized_type in {"permission_denied", "forbidden"}: - return LLMErrorCategory.PERMISSION_DENIED - if normalized_type in {"not_found_error", "model_not_found"}: - return LLMErrorCategory.MODEL_NOT_FOUND - if normalized_type in 
{"context_length_exceeded", "context_window_exceeded"}: - return LLMErrorCategory.CONTEXT_LIMIT - return None - - -def _category_from_message(raw: str) -> LLMErrorCategory | None: - lowered = raw.lower() - if any( - hint in lowered - for hint in ("rate limit", "rate-limited", "temporarily rate-limited") - ): - return LLMErrorCategory.RATE_LIMITED - if any( - hint in lowered - for hint in ( - "invalid api key", - "invalid_api_key", - "authentication", - "unauthorized", - "user not found", - "api key is expired", - "expired api key", - ) - ): - return LLMErrorCategory.AUTH_FAILED - if "forbidden" in lowered or "permission denied" in lowered: - return LLMErrorCategory.PERMISSION_DENIED - if "model not found" in lowered: - return LLMErrorCategory.MODEL_NOT_FOUND - if any( - hint in lowered - for hint in ( - "context length", - "context window", - "maximum context", - "too many tokens", - ) - ): - return LLMErrorCategory.CONTEXT_LIMIT - return None - - -def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation: - raw = str(exc) - parsed = _parse_error_payload(raw) - status_code = _extract_provider_status_code(parsed) - provider_error_type = _extract_provider_error_type(parsed) - - category = ( - _category_from_provider_payload(status_code, provider_error_type) - or _category_from_message(raw) - or _category_from_class_name(exc) - or LLMErrorCategory.UNKNOWN - ) - return LLMErrorAdaptation( - category=category, - retryable=category in _RETRYABLE_CATEGORIES, - user_message=_CATEGORY_MESSAGES[category], - provider_status_code=status_code, - provider_error_type=provider_error_type, - ) - - -def llm_error_message(exc: BaseException) -> str: - return adapt_llm_exception(exc).user_message diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 3affdcce7..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,7 
+30,6 @@ from litellm.exceptions import ( ) from pydantic import Field -from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -97,6 +96,53 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 +# Provider mapping for LiteLLM model string construction +PROVIDER_MAP = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", # Legacy support + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were +# hoisted to ``app.services.provider_api_base`` so vision and image-gen +# call sites can share the exact same defense (OpenRouter / Groq / etc. +# 404-ing against an inherited Azure endpoint). Re-exported here for +# backward compatibility with any external import. 
+from app.services.provider_api_base import ( # noqa: E402 + resolve_api_base, +) + class LLMRouterService: """ @@ -374,11 +420,38 @@ class LLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], + # Build model string + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" + else: + provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + # Build litellm params + litellm_params = { + "model": model_string, + "api_key": config.get("api_key"), + } + + # Resolve ``api_base``. Config value wins; otherwise apply a + # provider-aware default so the deployment does not silently + # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route + # requests to the wrong endpoint. See ``provider_api_base`` + # docstring for the motivating bug (OpenRouter models 404-ing + # against an Azure endpoint). 
+ api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), ) - litellm_params = {"model": model_string, **resolved_kwargs} + if api_base: + litellm_params["api_base"] = api_base + + # Add any additional litellm parameters + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) # Extract rate limits if provided deployment = { diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index e535d0150..7061a826f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -6,21 +6,17 @@ from langchain_core.messages import HumanMessage from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload from app.config import config -from app.db import Model, SearchSpace -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) +from app.db import NewLLMConfig, SearchSpace from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, + LLMRouterService, + get_auto_mode_llm, is_auto_mode, ) -from app.services.model_capabilities import has_capability -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -70,29 +66,6 @@ def _is_interactive_auth_provider( return False -def _legacy_config_connection( - *, - provider: str, - model_name: str, - api_key: str | None, - api_base: str | None, - custom_provider: str | None = None, - litellm_params: dict | None = None, - api_version: str | None = None, -) -> tuple[str, dict]: - cfg = { - "provider": provider.lower(), - "model_name": model_name, - 
"api_key": api_key, - "api_base": api_base, - "custom_provider": custom_provider, - "api_version": api_version, - "litellm_params": litellm_params or {}, - } - conn = native_connection_from_config(cfg) - return to_litellm(conn, model_name) - - class LLMRole: AGENT = "agent" # For agent/chat operations @@ -100,16 +73,26 @@ class LLMRole: def get_global_llm_config(llm_config_id: int) -> dict | None: """ Get a global LLM configuration by ID. - Global configs have negative IDs. Auto mode (ID 0) is resolved through the - model-candidate pipeline, not this legacy config lookup. + Global configs have negative IDs. ID 0 is reserved for Auto mode. Args: - llm_config_id: The ID of the global config (must be negative) + llm_config_id: The ID of the global config (should be negative or 0 for Auto) Returns: dict: Global config dictionary or None if not found """ - if llm_config_id >= 0: + # Auto mode (ID 0) is handled separately via the router + if llm_config_id == AUTO_MODE_ID: + return { + "id": AUTO_MODE_ID, + "name": "Auto (Fastest)", + "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + + if llm_config_id > 0: return None for cfg in config.GLOBAL_LLM_CONFIGS: @@ -119,55 +102,6 @@ def get_global_llm_config(llm_config_id: int) -> dict | None: return None -def get_global_model(model_id: int) -> dict | None: - return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) - - -def get_global_connection(connection_id: int) -> dict | None: - return next( - (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), - None, - ) - - -def _has_capability(model: dict | Model, capability: str) -> bool: - return has_capability(model, capability) - - -def _chat_litellm_from_resolved( - *, - conn: dict | object, - model_id: str, - disable_streaming: bool = False, -) -> tuple[str, dict]: - model_string, 
resolved_kwargs = to_litellm(conn, model_id) - litellm_kwargs = {"model": model_string, **resolved_kwargs} - if disable_streaming: - litellm_kwargs["disable_streaming"] = True - return model_string, litellm_kwargs - - -async def _get_db_model( - session: AsyncSession, - model_id: int, - search_space: SearchSpace, -) -> Model | None: - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id, Model.enabled.is_(True)) - ) - model = result.scalars().first() - if not model or not model.connection or not model.connection.enabled: - return None - conn = model.connection - if conn.search_space_id and conn.search_space_id != search_space.id: - return None - if conn.user_id and conn.user_id != search_space.user_id: - return None - return model - - async def validate_llm_config( provider: str, model_name: str, @@ -212,15 +146,62 @@ async def validate_llm_config( return False, msg try: - model_string, resolved_kwargs = _legacy_config_connection( - provider=provider, - model_name=model_name, - api_key=api_key, - api_base=api_base, - custom_provider=custom_provider, - litellm_params=litellm_params, - ) - litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30} + # Build the model string for litellm + if custom_provider: + model_string = f"{custom_provider}/{model_name}" + else: + # Map provider enum to litellm format + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility) + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + 
"CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + # Chinese LLM providers + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", # GLM needs special handling + "MINIMAX": "openai", + "GITHUB_MODELS": "github", + } + provider_prefix = provider_map.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{model_name}" + + # Create ChatLiteLLM instance + litellm_kwargs = { + "model": model_string, + "api_key": api_key, + "timeout": 30, # Set a timeout for validation + } + + # Add optional parameters + if api_base: + litellm_kwargs["api_base"] = api_base + + # Add any additional litellm parameters + if litellm_params: + litellm_kwargs.update(litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -302,9 +283,9 @@ async def get_search_space_llm_instance( logger.error(f"Search space {search_space_id} not found") return None - # Get the appropriate model binding ID based on role + # Get the appropriate LLM config ID based on role if role == LLMRole.AGENT: - llm_config_id = search_space.chat_model_id + llm_config_id = search_space.agent_llm_id else: logger.error(f"Invalid LLM role: {role}") return None @@ -313,42 +294,88 @@ async def get_search_space_llm_instance( logger.error(f"No {role} LLM configured for search space {search_space_id}") return None - # Auto mode resolves to one concrete global or BYOK model from the - # unified model-connections catalog. 
+ # Check for Auto mode (ID 0) - use router for load balancing if is_auto_mode(llm_config_id): - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=search_space.user_id, - capability="chat", - ) - if not candidates: - logger.error("No chat-capable models available for Auto mode") - return None - llm_config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] - ) - - # Check if this is a global virtual model (negative ID) - if llm_config_id < 0: - global_model = get_global_model(llm_config_id) - if not global_model or not _has_capability(global_model, "chat"): - logger.error(f"Global chat model {llm_config_id} not found") - return None - global_connection = get_global_connection(global_model["connection_id"]) - if not global_connection: + if not LLMRouterService.is_initialized(): logger.error( - "Global connection %s not found for model %s", - global_model["connection_id"], - llm_config_id, + "Auto mode requested but LLM Router not initialized. " + "Ensure global_llm_config.yaml exists with valid configs." 
) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=global_model["model_id"], - disable_streaming=disable_streaming, - ) + try: + logger.debug( + f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" + ) + return get_auto_mode_llm(streaming=not disable_streaming) + except Exception as e: + logger.error(f"Failed to create ChatLiteLLMRouter: {e}") + return None + + # Check if this is a global config (negative ID) + if llm_config_id < 0: + global_config = get_global_llm_config(llm_config_id) + if not global_config: + logger.error(f"Global LLM config {llm_config_id} not found") + return None + + # Build model string for global config + if global_config.get("custom_provider"): + model_string = ( + f"{global_config['custom_provider']}/{global_config['model_name']}" + ) + else: + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "MINIMAX": "openai", + } + provider_prefix = provider_map.get( + global_config["provider"], global_config["provider"].lower() + ) + model_string = f"{provider_prefix}/{global_config['model_name']}" + + # Create ChatLiteLLM instance from global config + litellm_kwargs = { + "model": model_string, + "api_key": global_config["api_key"], + } + + if 
global_config.get("api_base"): + litellm_kwargs["api_base"] = global_config["api_base"] + + if global_config.get("litellm_params"): + litellm_kwargs.update(global_config["litellm_params"]) + + if disable_streaming: + litellm_kwargs["disable_streaming"] = True from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -356,18 +383,80 @@ async def get_search_space_llm_instance( return SanitizedChatLiteLLM(**litellm_kwargs) - model = await _get_db_model(session, llm_config_id, search_space) - if not model or not _has_capability(model, "chat"): + # Get the LLM configuration from database (NewLLMConfig) + result = await session.execute( + select(NewLLMConfig).where( + NewLLMConfig.id == llm_config_id, + NewLLMConfig.search_space_id == search_space_id, + ) + ) + llm_config = result.scalars().first() + + if not llm_config: logger.error( - f"Chat model {llm_config_id} not found in search space {search_space_id}" + f"LLM config {llm_config_id} not found in search space {search_space_id}" ) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=model.connection, - model_id=model.model_id, - disable_streaming=disable_streaming, - ) + # Build the model string for litellm + if llm_config.custom_provider: + model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" + else: + # Map provider enum to litellm format + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": 
"cloudflare", + "DATABRICKS": "databricks", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "MINIMAX": "openai", + "GITHUB_MODELS": "github", + } + provider_prefix = provider_map.get( + llm_config.provider.value, llm_config.provider.value.lower() + ) + model_string = f"{provider_prefix}/{llm_config.model_name}" + + # Create ChatLiteLLM instance + litellm_kwargs = { + "model": model_string, + "api_key": llm_config.api_key, + } + + # Add optional parameters + if llm_config.api_base: + litellm_kwargs["api_base"] = llm_config.api_base + + # Add any additional litellm parameters + if llm_config.litellm_params: + litellm_kwargs.update(llm_config.litellm_params) + + if disable_streaming: + litellm_kwargs["disable_streaming"] = True from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -385,7 +474,7 @@ async def get_search_space_llm_instance( async def get_agent_llm( session: AsyncSession, search_space_id: int, disable_streaming: bool = False ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Get the search space's chat model instance.""" + """Get the search space's agent LLM instance for chat operations.""" return await get_search_space_llm_instance( session, search_space_id, @@ -399,17 +488,24 @@ async def get_vision_llm( ) -> ChatLiteLLM | ChatLiteLLMRouter | None: """Get the search space's vision LLM instance for screenshot analysis. - Resolves from the new connection/model role bindings: - - Auto mode (ID 0): unified global/BYOK model candidate selection - - Global (negative ID): virtual GLOBAL models from YAML - - DB (positive ID): Model + Connection tables + Resolves from the dedicated VisionLLMConfig system: + - Auto mode (ID 0): VisionLLMRouterService + - Global (negative ID): YAML configs + - DB (positive ID): VisionLLMConfig table Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` so each ``ainvoke`` debits the search-space owner's premium credit pool. 
User-owned BYOK configs and free global configs are returned unwrapped — they don't consume premium credit (issue M). """ + from app.db import VisionLLMConfig from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM + from app.services.vision_llm_router_service import ( + VISION_PROVIDER_MAP, + VisionLLMRouterService, + get_global_vision_llm_config, + is_vision_auto_mode, + ) try: result = await session.execute( @@ -420,78 +516,64 @@ async def get_vision_llm( logger.error(f"Search space {search_space_id} not found") return None - owner_user_id = search_space.user_id - - # Prefer the selected chat model when it is vision-capable. - chat_model_id = search_space.chat_model_id - if chat_model_id and chat_model_id != AUTO_MODE_ID: - if chat_model_id < 0: - chat_model = get_global_model(chat_model_id) - if chat_model and _has_capability(chat_model, "vision"): - global_connection = get_global_connection( - chat_model["connection_id"] - ) - if global_connection: - model_string, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=chat_model["model_id"], - ) - from app.agents.chat.runtime.llm_config import ( - SanitizedChatLiteLLM, - ) - - return SanitizedChatLiteLLM(**litellm_kwargs) - else: - chat_model = await _get_db_model(session, chat_model_id, search_space) - if chat_model and _has_capability(chat_model, "vision"): - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=chat_model.connection, - model_id=chat_model.model_id, - ) - from app.agents.chat.runtime.llm_config import ( - SanitizedChatLiteLLM, - ) - - return SanitizedChatLiteLLM(**litellm_kwargs) - - config_id = search_space.vision_model_id + config_id = search_space.vision_llm_config_id if config_id is None: logger.error(f"No vision LLM configured for search space {search_space_id}") return None - if config_id == AUTO_MODE_ID: - candidates = await auto_model_candidates( - session, - search_space_id=search_space_id, - user_id=owner_user_id, - capability="vision", 
- ) - if not candidates: - logger.error("No vision-capable models available for Auto mode") - return None - config_id = int( - choose_auto_model_candidate(candidates, search_space_id)["id"] - ) + owner_user_id = search_space.user_id - if config_id < 0: - global_model = get_global_model(config_id) - if not global_model or not _has_capability(global_model, "vision"): - logger.error(f"Global vision model {config_id} not found") - return None - - global_connection = get_global_connection(global_model["connection_id"]) - if not global_connection: + if is_vision_auto_mode(config_id): + if not VisionLLMRouterService.is_initialized(): logger.error( - "Global connection %s not found for model %s", - global_model["connection_id"], - config_id, + "Vision Auto mode requested but Vision LLM Router not initialized" ) return None + try: + # Auto mode is currently treated as free at the wrapper + # level — the underlying router can dispatch to either + # premium or free YAML configs but routing decisions are + # opaque. If/when we want to bill Auto-routed vision + # calls we'd need to thread the resolved deployment's + # billing_tier back from the router. For now we keep + # parity with chat Auto, which also doesn't pre-classify. 
+ return ChatLiteLLMRouter( + router=VisionLLMRouterService.get_router(), + streaming=True, + ) + except Exception as e: + logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}") + return None - model_string, litellm_kwargs = _chat_litellm_from_resolved( - conn=global_connection, - model_id=global_model["model_id"], + if config_id < 0: + global_cfg = get_global_vision_llm_config(config_id) + if not global_cfg: + logger.error(f"Global vision LLM config {config_id} not found") + return None + + if global_cfg.get("custom_provider"): + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" + else: + provider_prefix = VISION_PROVIDER_MAP.get( + global_cfg["provider"].upper(), + global_cfg["provider"].lower(), + ) + model_string = f"{provider_prefix}/{global_cfg['model_name']}" + + litellm_kwargs = { + "model": model_string, + "api_key": global_cfg["api_key"], + } + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), ) + if api_base: + litellm_kwargs["api_base"] = api_base + if global_cfg.get("litellm_params"): + litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -499,7 +581,7 @@ async def get_vision_llm( inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) - billing_tier = str(global_model.get("billing_tier", "free")).lower() + billing_tier = str(global_cfg.get("billing_tier", "free")).lower() if billing_tier == "premium": return QuotaCheckedVisionLLM( inner_llm, @@ -507,23 +589,47 @@ async def get_vision_llm( search_space_id=search_space_id, billing_tier=billing_tier, base_model=model_string, - quota_reserve_tokens=global_model.get("catalog", {}).get( - "quota_reserve_tokens" - ), + quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), ) return inner_llm - model = await _get_db_model(session, config_id, search_space) - if not 
model or not _has_capability(model, "vision"): + # User-owned (positive ID) BYOK configs — always free. + result = await session.execute( + select(VisionLLMConfig).where( + VisionLLMConfig.id == config_id, + VisionLLMConfig.search_space_id == search_space_id, + ) + ) + vision_cfg = result.scalars().first() + if not vision_cfg: logger.error( - f"Vision model {config_id} not found in search space {search_space_id}" + f"Vision LLM config {config_id} not found in search space {search_space_id}" ) return None - _, litellm_kwargs = _chat_litellm_from_resolved( - conn=model.connection, - model_id=model.model_id, + if vision_cfg.custom_provider: + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" + else: + provider_prefix = VISION_PROVIDER_MAP.get( + vision_cfg.provider.value.upper(), + vision_cfg.provider.value.lower(), + ) + model_string = f"{provider_prefix}/{vision_cfg.model_name}" + + litellm_kwargs = { + "model": model_string, + "api_key": vision_cfg.api_key, + } + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, ) + if api_base: + litellm_kwargs["api_base"] = api_base + if vision_cfg.litellm_params: + litellm_kwargs.update(vision_cfg.litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, diff --git a/surfsense_backend/app/services/model_capabilities.py b/surfsense_backend/app/services/model_capabilities.py deleted file mode 100644 index fb7681f35..000000000 --- a/surfsense_backend/app/services/model_capabilities.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Override-aware model capability lookup.""" - -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any - -CAPABILITY_FIELDS = { - "chat": "supports_chat", - "vision": "supports_image_input", - "image_gen": "supports_image_generation", - "tools": "supports_tools", -} - - -def _get_value(model: Any, key: 
str) -> Any: - if isinstance(model, Mapping): - return model.get(key) - return getattr(model, key, None) - - -def has_capability(model: Any, capability: str) -> bool: - field = CAPABILITY_FIELDS.get(capability) - if field is None: - return False - - override = _get_value(model, "capabilities_override") or {} - if isinstance(override, Mapping) and field in override: - return bool(override[field]) - if isinstance(override, Mapping) and capability in override: - return bool(override[capability]) - - return bool(_get_value(model, field)) - - -__all__ = ["CAPABILITY_FIELDS", "has_capability"] diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py deleted file mode 100644 index cdfd1d725..000000000 --- a/surfsense_backend/app/services/model_connection_service.py +++ /dev/null @@ -1,490 +0,0 @@ -"""Connection verification, model discovery, and capability probing.""" - -from __future__ import annotations - -import contextlib -import logging -from dataclasses import dataclass -from typing import Any - -import anyio -import httpx -import litellm - -from app.db import Connection, Model, ModelSource -from app.services.model_resolver import ensure_v1, to_litellm -from app.services.openrouter_model_normalizer import normalize_openrouter_models -from app.services.provider_registry import Transport, provider_label, spec_for - -logger = logging.getLogger(__name__) - -VERIFY_TIMEOUT_SECONDS = 8.0 -DISCOVERY_TIMEOUT_SECONDS = 15.0 -TEST_TIMEOUT_SECONDS = 30.0 - - -@dataclass(frozen=True) -class VerifyResult: - status: str - ok: bool - message: str = "" - - -class ModelDiscoveryError(Exception): - """User-correctable discovery failure for provider configuration issues.""" - - -def _auth_headers(conn: Connection) -> dict[str, str]: - if not conn.api_key: - return {} - return {"Authorization": f"Bearer {conn.api_key}"} - - -def _anthropic_headers(conn: Connection) -> dict[str, str]: - headers = 
{"anthropic-version": "2023-06-01"} - if conn.api_key: - headers["x-api-key"] = conn.api_key - return headers - - -def _base_url_or_default(conn: Connection) -> str | None: - if conn.base_url: - return conn.base_url.rstrip("/") - if conn.provider == "openai": - return "https://api.openai.com/v1" - if conn.provider == "anthropic": - return "https://api.anthropic.com/v1" - return spec_for(conn.provider).default_base_url - - -def _docker_hint(url: str | None, exc_or_status: Any) -> str: - raw = str(exc_or_status) - if not url: - return raw - if "localhost" in url or "127.0.0.1" in url: - return ( - f"{raw}. The backend is running inside Docker; localhost means the " - "backend container. Use host.docker.internal and make sure the model " - "server listens on 0.0.0.0." - ) - if "host.docker.internal" in url and ( - "refused" in raw.lower() or "connect" in raw.lower() - ): - return ( - f"{raw}. The host is reachable only if your local model server is " - "listening on 0.0.0.0. On Linux Docker, add " - "`host.docker.internal:host-gateway` to extra_hosts." - ) - return raw - - -def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult: - provider_name = provider_label(conn.provider) - raw = str(exc) - normalized = raw.lower() - exc_name = exc.__class__.__name__.lower() - status_code = getattr(exc, "status_code", None) - - logger.info( - "Model test failed for provider=%s model=%s: %s", - conn.provider, - model_id, - raw, - ) - - if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized: - return VerifyResult( - "AUTH_FAILED", - False, - f"Authentication failed. Check your {provider_name} credentials and try again.", - ) - - if status_code == 404 or "notfound" in exc_name or "not found" in normalized: - if conn.provider == "azure": - message = ( - "Azure OpenAI deployment was not found. Check the deployment name, " - "API version, and endpoint." 
- ) - else: - message = f"Model '{model_id}' was not found on {provider_name}." - return VerifyResult("NOT_FOUND", False, message) - - if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized: - return VerifyResult( - "RATE_LIMITED", - False, - f"{provider_name} rate limited the model test. Try again later.", - ) - - if "timeout" in exc_name or "timed out" in normalized: - return VerifyResult( - "TIMEOUT", - False, - f"{provider_name} did not respond in time. Check the endpoint and try again.", - ) - - if "connection" in exc_name or "connect" in normalized: - return VerifyResult( - "UNREACHABLE", - False, - _docker_hint( - _base_url_or_default(conn), - f"Could not reach {provider_name}. Check the endpoint and try again.", - ), - ) - - return VerifyResult( - "UNREACHABLE", - False, - f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.", - ) - - -async def verify_connection(conn: Connection) -> VerifyResult: - spec = spec_for(conn.provider) - base_url = _base_url_or_default(conn) - if spec.base_url_required and not base_url: - return VerifyResult("UNREACHABLE", False, "Base URL is required.") - - if spec.transport == Transport.OLLAMA and base_url: - url = f"{base_url.rstrip('/')}/api/version" - elif spec.discovery in {"openai_models", "openrouter"} and base_url: - url = f"{ensure_v1(base_url)}/models" - elif spec.discovery == "anthropic_models" and base_url: - url = f"{base_url.rstrip('/')}/models" - else: - return VerifyResult( - "OK", True, "Connection uses provider-native authentication." 
- ) - - try: - async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - headers = ( - _anthropic_headers(conn) - if spec.auth_style == "x-api-key" - else _auth_headers(conn) - ) - response = await client.get(url, headers=headers) - if response.status_code in (401, 403): - return VerifyResult("AUTH_FAILED", False, "Authentication failed.") - if response.status_code == 404: - if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"): - message = "Ollama native API should not use /v1." - elif spec.transport == Transport.OPENAI_COMPATIBLE: - message = "OpenAI-compatible servers should expose /v1/models." - else: - message = "Endpoint returned 404." - return VerifyResult("NOT_FOUND", False, message) - response.raise_for_status() - return VerifyResult("OK", True, "Connection verified.") - except httpx.ConnectError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) - except httpx.TimeoutException as exc: - return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") - except httpx.HTTPError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) - - -def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str: - base_url = _base_url_or_default(conn) - if isinstance(exc, httpx.HTTPStatusError): - status_code = exc.response.status_code - if status_code in (401, 403): - return "Authentication failed while discovering models." - if status_code == 404: - spec = spec_for(conn.provider) - if spec.transport == Transport.OPENAI_COMPATIBLE: - return "OpenAI-compatible servers should expose /v1/models." - return "Model discovery endpoint returned 404." - return f"Model discovery failed with HTTP {status_code}." 
- if isinstance(exc, httpx.TimeoutException): - return f"Model discovery timed out: {exc}" - return _docker_hint(base_url, exc) - - -def _allowlist(conn: Connection) -> set[str]: - raw = (conn.extra or {}).get("model_ids") or [] - return {str(item).strip() for item in raw if str(item).strip()} - - -def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]: - with contextlib.suppress(Exception): - info = litellm.get_model_info(model=model_string) - if isinstance(info, dict): - return info - return ( - litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} - ) - - -def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]: - info = _litellm_info(model_string, model_id) - mode = info.get("mode") - supports_image_input = False - supports_tools = False - with contextlib.suppress(Exception): - supports_image_input = bool(litellm.supports_vision(model=model_string)) - with contextlib.suppress(Exception): - supports_tools = bool(litellm.supports_function_calling(model=model_string)) - return { - "supports_chat": mode in (None, "chat", "completion", "responses"), - "max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"), - "supports_image_input": supports_image_input, - "supports_tools": supports_tools, - "supports_image_generation": mode - in {"image_generation", "image_generation_model"}, - } - - -def derive_capabilities( - conn: Connection, model_id: str, metadata: dict | None = None -) -> dict[str, Any]: - metadata = metadata or {} - spec = spec_for(conn.provider) - model_string, _ = to_litellm(conn, model_id) - facts = _classify_from_litellm(model_string, model_id) - if spec.transport == Transport.OLLAMA: - caps = set(metadata.get("capabilities") or []) - details = metadata.get("details") or {} - facts.update( - { - "supports_chat": "embedding" not in caps, - "supports_image_input": "vision" in caps - or facts["supports_image_input"], - "supports_tools": "tools" in caps or 
facts["supports_tools"], - "supports_image_generation": False, - "max_input_tokens": metadata.get("context_length") - or metadata.get("num_ctx") - or details.get("context_length") - or facts["max_input_tokens"], - } - ) - return facts - - -async def _discover_openai_shaped_models( - conn: Connection, base_url: str | None -) -> list[dict[str, Any]]: - resolved_base_url = base_url or _base_url_or_default(conn) - if not resolved_base_url: - return [] - - url = f"{ensure_v1(resolved_base_url)}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - - results: list[dict[str, Any]] = [] - for item in response.json().get("data", []): - model_id = item.get("id") - if not model_id: - continue - results.append( - { - "model_id": model_id, - "display_name": item.get("name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, item), - "metadata": item, - } - ) - return results - - -async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: - base_url = _base_url_or_default(conn) - if not base_url: - return [] - - url = f"{base_url.rstrip('/')}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_anthropic_headers(conn)) - response.raise_for_status() - - results: list[dict[str, Any]] = [] - for item in response.json().get("data", []): - model_id = item.get("id") - if not model_id: - continue - results.append( - { - "model_id": model_id, - "display_name": item.get("display_name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, item), - "metadata": item, - } - ) - return results - - -async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]: - if not conn.base_url: - return [] - - base_url = conn.base_url.rstrip("/") - async with 
httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("models", []) - results: list[dict[str, Any]] = [] - for item in models: - model_id = item.get("model") or item.get("name") - if not model_id: - continue - metadata = dict(item) - with contextlib.suppress(Exception): - show_response = await client.post( - f"{base_url}/api/show", - json={"model": model_id}, - headers=_auth_headers(conn), - ) - show_response.raise_for_status() - metadata.update(show_response.json()) - results.append( - { - "model_id": model_id, - "display_name": item.get("name") or model_id, - "source": ModelSource.DISCOVERED, - **derive_capabilities(conn, model_id, metadata), - "metadata": metadata, - } - ) - return results - - -async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]: - base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get( - f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn) - ) - response.raise_for_status() - return normalize_openrouter_models(response.json().get("data", [])) - - -def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: - provider = conn.provider - prefix = spec_for(provider).litellm_prefix or provider - results: list[dict[str, Any]] = [] - for model_string, metadata in litellm.model_cost.items(): - if not isinstance(model_string, str) or not model_string.startswith( - f"{prefix}/" - ): - continue - model_id = model_string.split("/", 1)[1] - results.append( - { - "model_id": model_id, - "display_name": metadata.get("display_name") or model_id, - "source": ModelSource.DISCOVERED, - **_classify_from_litellm(model_string, model_id), - "metadata": metadata, - } - ) - return results - - -async def _discover_bedrock_models(conn: Connection) -> 
list[dict[str, Any]]: - params = (conn.extra or {}).get("litellm_params", {}) - region_name = params.get("aws_region_name") - if not region_name: - return [] - - def list_models() -> list[dict[str, Any]]: - import os - - import boto3 - - if bearer_token := params.get("aws_bearer_token_bedrock"): - try: - os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token - client = boto3.client("bedrock", region_name=region_name) - finally: - os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None) - else: - client_kwargs: dict[str, str] = {"region_name": region_name} - if params.get("aws_access_key_id"): - client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] - if params.get("aws_secret_access_key"): - client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] - client = boto3.client("bedrock", **client_kwargs) - - response = client.list_foundation_models() - results: list[dict[str, Any]] = [] - for item in response.get("modelSummaries", []): - model_id = item.get("modelId") - if not model_id: - continue - input_modalities = set(item.get("inputModalities") or []) - output_modalities = set(item.get("outputModalities") or []) - results.append( - { - "model_id": model_id, - "display_name": item.get("modelName") or model_id, - "source": ModelSource.DISCOVERED, - "supports_chat": "TEXT" in input_modalities - and "TEXT" in output_modalities, - "supports_image_input": "IMAGE" in input_modalities, - "supports_tools": None, - "supports_image_generation": "IMAGE" in output_modalities, - "max_input_tokens": None, - "metadata": item, - } - ) - return results - - return await anyio.to_thread.run_sync(list_models) - - -async def discover_models(conn: Connection) -> list[dict[str, Any]]: - allowlist = _allowlist(conn) - spec = spec_for(conn.provider) - - try: - if spec.discovery == "ollama": - results = await _ollama_tags_then_show(conn) - elif spec.discovery == "openrouter": - results = await _openrouter_models(conn) - elif spec.discovery == "anthropic_models": - results = 
await _discover_anthropic_models(conn) - elif spec.discovery == "openai_models": - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif spec.discovery == "bedrock_models": - results = await _discover_bedrock_models(conn) - elif spec.discovery == "static": - results = _litellm_static_models(conn) - else: - results = [] - except httpx.HTTPError as exc: - raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc - - if allowlist: - results = [item for item in results if item["model_id"] in allowlist] - return results - - -async def test_model(conn: Connection, model: Model) -> VerifyResult: - model_string, kwargs = to_litellm(conn, model.model_id) - try: - await litellm.acompletion( - model=model_string, - messages=[{"role": "user", "content": "Hello"}], - timeout=TEST_TIMEOUT_SECONDS, - **kwargs, - ) - except Exception as exc: - return _model_test_error(conn, model.model_id, exc) - - model.supports_chat = True - return VerifyResult("OK", True, "Model test succeeded.") - - -__all__ = [ - "ModelDiscoveryError", - "VerifyResult", - "derive_capabilities", - "discover_models", - "test_model", - "verify_connection", -] diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index ffb430756..33837a8a0 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -12,8 +12,6 @@ from pathlib import Path import httpx -from app.services.openrouter_model_normalizer import normalize_openrouter_models - logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" @@ -24,7 +22,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours _cache: list[dict] | None = None _cache_timestamp: float = 0 -# Maps OpenRouter provider slug to native LiteLLM provider prefixes. +# Maps OpenRouter provider slug → our LiteLLMProvider enum value. 
# Only providers where the model-name part (after the slash) can be # used directly with the native provider's litellm prefix are listed. # @@ -123,13 +121,26 @@ def _process_models(raw_models: list[dict]) -> list[dict]: """ processed: list[dict] = [] - for normalized in normalize_openrouter_models(raw_models): - model_id: str = normalized["model_id"] - name: str = normalized.get("display_name") or model_id - context_length = normalized.get("max_input_tokens") + for model in raw_models: + model_id: str = model.get("id", "") + name: str = model.get("name", "") + context_length = model.get("context_length") + if "/" not in model_id: continue + if not _is_text_output_model(model): + continue + + if not _supports_tool_calling(model): + continue + + if not _has_sufficient_context(model): + continue + + if not _is_allowed_model(model): + continue + provider_slug, model_name = model_id.split("/", 1) context_window = _format_context_length(context_length) @@ -143,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]: } ) - # 2) Emit for the direct provider when we have a mapping - direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) - if direct_provider: + # 2) Emit for the native provider when we have a mapping + native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) + if native_provider: # Google's Gemini API only serves gemini-* models. # Open-source models like gemma-* are NOT available through it. 
- if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): + if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": direct_provider, + "provider": native_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py deleted file mode 100644 index 628c9f473..000000000 --- a/surfsense_backend/app/services/model_resolver.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Single model-to-LiteLLM resolver. - -All chat, vision, image-generation, validation, and Auto routing paths should -turn a Connection + Model into LiteLLM input through this module. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from app.db import Connection - -from app.services.provider_registry import Transport, spec_for - - -def ensure_v1(base_url: str | None) -> str | None: - if not base_url: - return None - stripped = base_url.rstrip("/") - if stripped.endswith("/v1"): - return stripped - return f"{stripped}/v1" - - -def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: - if isinstance(conn, Mapping): - return conn.get(key) - return getattr(conn, key) - - -def to_litellm( - conn: Connection | Mapping[str, Any], - model_id: str, -) -> tuple[str, dict[str, Any]]: - """Return ``(model_string, litellm_kwargs)`` for any model role.""" - provider = _conn_value(conn, "provider") - base_url = _conn_value(conn, "base_url") - api_key = _conn_value(conn, "api_key") - extra = _conn_value(conn, "extra") or {} - spec = spec_for(provider) - - kwargs: dict[str, Any] = {} - if api_key: - kwargs["api_key"] = api_key - - prefix = spec.litellm_prefix or str(provider) - model_string = f"{prefix}/{model_id}" if prefix else model_id - if base_url: - api_base = ( - ensure_v1(base_url) - if spec.transport == 
Transport.OPENAI_COMPATIBLE - else base_url.rstrip("/") - ) - kwargs["api_base"] = api_base - - if api_version := extra.get("api_version"): - kwargs["api_version"] = api_version - kwargs.update(extra.get("litellm_params", {})) - kwargs.update(extra.get("kwargs", {})) - if provider == "bedrock" and ( - bearer_token := kwargs.pop("aws_bearer_token_bedrock", None) - ): - kwargs["api_key"] = bearer_token - return model_string, kwargs - - -def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: - """Build an in-memory connection mapping from a global config.""" - provider = str( - config.get("provider") - or config.get("litellm_provider") - or config.get("custom_provider") - or "openai" - ) - extra: dict[str, Any] = { - "litellm_params": config.get("litellm_params") or {}, - } - if config.get("api_version"): - extra["api_version"] = config.get("api_version") - return { - "provider": provider, - "base_url": config.get("api_base") or None, - "api_key": config.get("api_key") or None, - "extra": extra, - } - - -__all__ = [ - "ensure_v1", - "native_connection_from_config", - "to_litellm", -] diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index cd05d7935..13f43d1ee 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -199,12 +199,11 @@ async def _extract_binary_attachment_markdown( async def _run_etl_extract(*, file_path: str, filename: str, vision_llm): """Lazy-load ETL dependencies to avoid module-import cycles.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - return await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename), - vision_llm=vision_llm, + return await EtlPipelineService(vision_llm=vision_llm).extract( + 
EtlRequest(file_path=file_path, filename=filename) ) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 17b8c10eb..6454e2d58 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -19,10 +19,6 @@ from typing import Any import httpx -from app.services.openrouter_model_normalizer import ( - is_openrouter_image_model, - normalize_openrouter_models, -) from app.services.quality_score import ( _HEALTH_BLEND_WEIGHT, _HEALTH_ENRICH_CONCURRENCY, @@ -278,7 +274,7 @@ def _generate_configs( OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer - because our own Auto pin + 24 h refresh + repair logic already + because our own Auto (Fastest) pin + 24 h refresh + repair logic already cover the catalogue-churn case. 
""" id_offset: int = settings.get("id_offset", -10000) @@ -296,16 +292,24 @@ def _generate_configs( use_default: bool = settings.get("use_default_system_instructions", True) citations_enabled: bool = settings.get("citations_enabled", True) - text_models = normalize_openrouter_models(raw_models) + text_models = [ + m + for m in raw_models + if _is_text_output_model(m) + and _supports_tool_calling(m) + and _has_sufficient_context(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] configs: list[dict] = [] taken: set[int] = set() now_ts = int(time.time()) - for normalized in text_models: - model = normalized.get("metadata") or {} - model_id: str = normalized["model_id"] - name: str = normalized.get("display_name") or model_id + for model in text_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) tier = _openrouter_tier(model) static_q = static_score_or(model, now_ts=now_ts) @@ -319,10 +323,10 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), @@ -341,9 +345,9 @@ def _generate_configs( # ``stream_new_chat`` as a fail-fast safety net before the # OpenRouter request would otherwise 404 with # ``"No endpoints found that support image input"``. - "supports_image_input": bool(normalized.get("supports_image_input")), + "supports_image_input": _supports_image_input(model), _OPENROUTER_DYNAMIC_MARKER: True, - # Auto ranking metadata. ``quality_score`` is initialised + # Auto (Fastest) ranking metadata. 
``quality_score`` is initialised # to the static score and gets re-blended with health on the next # ``_enrich_health`` pass (synchronous on refresh, deferred on cold # start so startup latency is unchanged). @@ -358,7 +362,11 @@ def _generate_configs( return configs +# ID-offset bands used to keep dynamic OpenRouter configs in their own +# namespace per surface. Image / vision get separate bands so a single +# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 +_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 def _generate_image_gen_configs( @@ -392,7 +400,14 @@ def _generate_image_gen_configs( free_rpm: int = settings.get("free_rpm", 20) litellm_params: dict = settings.get("litellm_params") or {} - image_models = [m for m in raw_models if is_openrouter_image_model(m)] + image_models = [ + m + for m in raw_models + if _is_image_output_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] configs: list[dict] = [] taken: set[int] = set() @@ -405,9 +420,14 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). 
"api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -420,6 +440,93 @@ def _generate_image_gen_configs( return configs +def _generate_vision_llm_configs( + raw_models: list[dict], settings: dict[str, Any] +) -> list[dict]: + """Convert OpenRouter vision-capable LLMs into global vision-LLM config + dicts (matches the YAML shape consumed by ``vision_llm_routes``). + + Filter: + - architecture.input_modalities contains "image" + - architecture.output_modalities contains "text" + - compatible provider (excluded slugs blocked) + - allowed model id (excluded list blocked) + + Vision-LLM is invoked from the indexer (image extraction during + document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so + the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` + filters do not apply: a small-context vision model that doesn't + advertise tool-calling is still perfectly viable for "describe this + image" prompts. + """ + id_offset: int = int( + settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT + ) + api_key: str = settings.get("api_key", "") + rpm: int = settings.get("rpm", 200) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) + litellm_params: dict = settings.get("litellm_params") or {} + + vision_models = [ + m + for m in raw_models + if _is_vision_input_model(m) + and _is_compatible_provider(m) + and _is_allowed_model(m) + and "/" in m.get("id", "") + ] + + configs: list[dict] = [] + taken: set[int] = set() + for model in vision_models: + model_id: str = model["id"] + name: str = model.get("name", model_id) + tier = _openrouter_tier(model) + pricing = model.get("pricing") or {} + + # Capture per-token prices so ``pricing_registration`` can + # register them with LiteLLM at startup (and so the cost + # 
estimator in ``estimate_call_reserve_micros`` can resolve + # them at reserve time). + try: + input_cost = float(pricing.get("prompt", 0) or 0) + except (TypeError, ValueError): + input_cost = 0.0 + try: + output_cost = float(pricing.get("completion", 0) or 0) + except (TypeError, ValueError): + output_cost = 0.0 + + cfg: dict[str, Any] = { + "id": _stable_config_id(model_id, id_offset, taken), + "name": name, + "description": f"{name} via OpenRouter (vision)", + "provider": "OPENROUTER", + "model_name": model_id, + "api_key": api_key, + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", + "api_version": None, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, + "litellm_params": dict(litellm_params), + "billing_tier": tier, + "quota_reserve_tokens": quota_reserve_tokens, + "input_cost_per_token": input_cost or None, + "output_cost_per_token": output_cost or None, + _OPENROUTER_DYNAMIC_MARKER: True, + } + configs.append(cfg) + + return configs + + class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -446,9 +553,11 @@ class OpenRouterIntegrationService: # Cached raw catalogue from the most recent fetch. Image / vision # emitters reuse this to avoid a second network call per surface. self._raw_models: list[dict] = [] - # Image config cache (only populated when the matching opt-in flag is - # true on initialize). Refreshed in lockstep with the chat catalogue. + # Image / vision config caches (only populated when the matching + # opt-in flag is true on initialize). Refreshed in lockstep with + # the chat catalogue. 
self._image_configs: list[dict] = [] + self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -483,7 +592,7 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._raw_pricing = _extract_raw_pricing(raw_models) - # Populate image cache when its opt-in flag is set. + # Populate image / vision caches when their opt-in flag is set. # Empty otherwise so the accessors return [] without re-running # filters every refresh. if settings.get("image_generation_enabled"): @@ -495,6 +604,15 @@ class OpenRouterIntegrationService: else: self._image_configs = [] + if settings.get("vision_enabled"): + self._vision_configs = _generate_vision_llm_configs(raw_models, settings) + logger.info( + "OpenRouter integration: vision LLM emission ON (%d models)", + len(self._vision_configs), + ) + else: + self._vision_configs = [] + self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -548,9 +666,9 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - # Image list is atomic-swapped the same way: filter out + # Image / vision lists are atomic-swapped the same way: filter out # the previous dynamic entries from the live config list and append - # the freshly generated ones. No-op when the opt-in flag is off. + # the freshly generated ones. No-ops when the opt-in flag is off. 
if self._settings.get("image_generation_enabled"): new_image = _generate_image_gen_configs(raw_models, self._settings) static_image = [ @@ -561,6 +679,16 @@ class OpenRouterIntegrationService: app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image self._image_configs = new_image + if self._settings.get("vision_enabled"): + new_vision = _generate_vision_llm_configs(raw_models, self._settings) + static_vision = [ + c + for c in app_config.GLOBAL_VISION_LLM_CONFIGS + if not c.get(_OPENROUTER_DYNAMIC_MARKER) + ] + app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision + self._vision_configs = new_vision + # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -582,7 +710,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with provider=openrouter + # re-stamps health for any YAML-curated cfg with provider==OPENROUTER # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -630,7 +758,7 @@ class OpenRouterIntegrationService: return counts # ------------------------------------------------------------------ - # Auto health enrichment + # Auto (Fastest) health enrichment # ------------------------------------------------------------------ async def _enrich_health_safely( @@ -659,7 +787,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. 
""" or_cfgs = [ - c for c in configs if str(c.get("provider", "")).lower() == "openrouter" + c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" ] if not or_cfgs: return @@ -840,6 +968,17 @@ class OpenRouterIntegrationService: """ return list(self._image_configs) + def get_vision_llm_configs(self) -> list[dict]: + """Return the dynamic OpenRouter vision-LLM configs (empty list + when the ``vision_enabled`` flag is off). + + Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` + so ``pricing_registration`` can teach LiteLLM the cost of these + models the same way it does for chat — which keeps the billable + wrapper able to debit accurate micro-USD on a vision call. + """ + return list(self._vision_configs) + def get_raw_pricing(self) -> dict[str, dict[str, str]]: """Return the cached raw OpenRouter pricing map. diff --git a/surfsense_backend/app/services/openrouter_model_normalizer.py b/surfsense_backend/app/services/openrouter_model_normalizer.py deleted file mode 100644 index 5998a2f1f..000000000 --- a/surfsense_backend/app/services/openrouter_model_normalizer.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Shared OpenRouter model normalization. - -OpenRouter metadata is richer than generic OpenAI-compatible ``/models`` -responses. Keep all OpenRouter filtering and capability extraction here so -GLOBAL catalogue generation and BYOK discovery agree. -""" - -from __future__ import annotations - -from typing import Any - -from app.db import ModelSource - -MIN_CONTEXT_LENGTH = 100_000 - -EXCLUDED_PROVIDER_SLUGS = {"amazon"} -EXCLUDED_MODEL_IDS: set[str] = { - "openai/gpt-4-1106-preview", - "openai/gpt-4-turbo-preview", - "openai/gpt-4o:extended", - "arcee-ai/virtuoso-large", - "openai/o3-deep-research", - "openai/o4-mini-deep-research", - "openrouter/free", -} -EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] 
= ("-deep-research",) - - -def is_text_output_model(model: dict[str, Any]) -> bool: - output_mods = model.get("architecture", {}).get("output_modalities", []) - return output_mods == ["text"] - - -def is_image_output_model(model: dict[str, Any]) -> bool: - output_mods = model.get("architecture", {}).get("output_modalities", []) or [] - return "image" in output_mods - - -def supports_image_input(model: dict[str, Any]) -> bool: - input_mods = model.get("architecture", {}).get("input_modalities", []) or [] - return "image" in input_mods - - -def supports_tool_calling(model: dict[str, Any]) -> bool: - supported = model.get("supported_parameters") or [] - return "tools" in supported - - -def has_sufficient_context(model: dict[str, Any]) -> bool: - return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH - - -def is_compatible_provider(model: dict[str, Any]) -> bool: - model_id = str(model.get("id") or "") - slug = model_id.split("/", 1)[0] if "/" in model_id else "" - return slug not in EXCLUDED_PROVIDER_SLUGS - - -def is_allowed_model(model: dict[str, Any]) -> bool: - model_id = str(model.get("id") or "") - if model_id in EXCLUDED_MODEL_IDS: - return False - base_id = model_id.split(":")[0] - return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES) - - -def is_openrouter_chat_model(model: dict[str, Any]) -> bool: - return ( - "/" in str(model.get("id") or "") - and is_text_output_model(model) - and supports_tool_calling(model) - and has_sufficient_context(model) - and is_compatible_provider(model) - and is_allowed_model(model) - ) - - -def is_openrouter_image_model(model: dict[str, Any]) -> bool: - return ( - "/" in str(model.get("id") or "") - and is_image_output_model(model) - and is_compatible_provider(model) - and is_allowed_model(model) - ) - - -def normalize_openrouter_models( - raw_models: list[dict[str, Any]], -) -> list[dict[str, Any]]: - normalized: list[dict[str, Any]] = [] - for model in raw_models: - if not is_openrouter_chat_model(model): - continue 
- model_id = str(model.get("id") or "") - normalized.append( - { - "model_id": model_id, - "display_name": model.get("name") or model_id, - "source": ModelSource.DISCOVERED, - "supports_chat": True, - "max_input_tokens": model.get("context_length"), - "supports_image_input": supports_image_input(model), - "supports_tools": supports_tool_calling(model), - "supports_image_generation": False, - "metadata": model, - } - ) - return normalized - - -__all__ = [ - "MIN_CONTEXT_LENGTH", - "has_sufficient_context", - "is_allowed_model", - "is_compatible_provider", - "is_image_output_model", - "is_openrouter_chat_model", - "is_openrouter_image_model", - "is_text_output_model", - "normalize_openrouter_models", - "supports_image_input", - "supports_tool_calling", -] diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 7343df737..de98e50c2 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,19 +143,21 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider") or "").upper() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() - if provider == "openrouter": + if provider == "OPENROUTER": entry = or_pricing.get(model_name) if entry: input_cost = _safe_float(entry.get("prompt")) output_cost = _safe_float(entry.get("completion")) else: - # Some dynamically materialized configs can carry pricing - # inline when the raw OpenRouter cache has no matching entry. 
+ # Vision configs from ``_generate_vision_llm_configs`` + # carry their pricing inline because the OpenRouter + # raw-pricing cache is keyed by chat-catalogue model_id; + # vision flows pick up the inline values here. input_cost = _safe_float(cfg.get("input_cost_per_token")) output_cost = _safe_float(cfg.get("output_cost_per_token")) if input_cost == 0.0 and output_cost == 0.0: @@ -187,11 +189,12 @@ def _register_chat_shape_configs( skipped_no_pricing += 1 continue aliases = _alias_set_for_yaml(provider, model_name, base_model) + provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() count = _register( aliases, input_cost=input_cost, output_cost=output_cost, - provider=provider, + provider=provider_slug, ) if count > 0: registered_models += 1 @@ -214,8 +217,9 @@ def _register_chat_shape_configs( def register_pricing_from_global_configs() -> None: """Register pricing for every known LLM deployment with LiteLLM. - Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve - cost from the same chat-shaped deployment configs: + Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` + so vision calls (during indexing) can resolve cost the same way chat + calls do — namely: 1. 
``OPENROUTER``: pulls the cached raw pricing from ``OpenRouterIntegrationService`` (populated during its own @@ -242,7 +246,10 @@ def register_pricing_from_global_configs() -> None: from app.config import config as app_config chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) - if not chat_configs: + vision_configs: list[dict] = list( + getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] + ) + if not chat_configs and not vision_configs: logger.info("[PricingRegistration] no global configs to register") return @@ -261,3 +268,7 @@ def register_pricing_from_global_configs() -> None: if chat_configs: _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") + if vision_configs: + _register_chat_shape_configs( + vision_configs, or_pricing=or_pricing, label="vision" + ) diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py new file mode 100644 index 000000000..dca1f9462 --- /dev/null +++ b/surfsense_backend/app/services/provider_api_base.py @@ -0,0 +1,106 @@ +"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. + +LiteLLM falls back to the module-global ``litellm.api_base`` when an +individual call doesn't pass one, which silently inherits provider-agnostic +env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an +explicit ``api_base``, an ``openrouter/`` request can end up at an +Azure endpoint and 404 with ``Resource not found`` (real reproducer: +[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends +``/chat/completions`` to whatever inherited base it gets, regardless of +provider). + +The chat router has had this defense for a while +(``llm_router_service.py:466-478``). This module hoists the maps + cascade +into a tiny standalone helper so vision and image-gen can share the same +source of truth without an inter-service circular import. 
+""" + +from __future__ import annotations + +PROVIDER_DEFAULT_API_BASE: dict[str, str] = { + "openrouter": "https://openrouter.ai/api/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "perplexity": "https://api.perplexity.ai", + "xai": "https://api.x.ai/v1", + "cerebras": "https://api.cerebras.ai/v1", + "deepinfra": "https://api.deepinfra.com/v1/openai", + "fireworks_ai": "https://api.fireworks.ai/inference/v1", + "together_ai": "https://api.together.xyz/v1", + "anyscale": "https://api.endpoints.anyscale.com/v1", + "cometapi": "https://api.cometapi.com/v1", + "sambanova": "https://api.sambanova.ai/v1", +} +"""Default ``api_base`` per LiteLLM provider prefix (lowercase). + +Only providers with a well-known, stable public base URL are listed — +self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, +huggingface, databricks, cloudflare, replicate) are intentionally omitted +so their existing config-driven behaviour is preserved.""" + + +PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { + "DEEPSEEK": "https://api.deepseek.com/v1", + "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "MOONSHOT": "https://api.moonshot.ai/v1", + "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", + "MINIMAX": "https://api.minimax.io/v1", +} +"""Canonical provider key (uppercase) → base URL. + +Used when the LiteLLM provider prefix is the generic ``openai`` shim but the +config's ``provider`` field tells us which API it actually is (DeepSeek, +Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each +has its own base URL).""" + + +def resolve_api_base( + *, + provider: str | None, + provider_prefix: str | None, + config_api_base: str | None, +) -> str | None: + """Resolve a non-Azure-leaking ``api_base`` for a deployment. + + Cascade (first non-empty wins): + 1. The config's own ``api_base`` (whitespace-only treated as missing). + 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. + 3. 
``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. + 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM + provider integration apply its own default (e.g. AzureOpenAI's + deployment-derived URL, custom provider's per-deployment URL). + + Args: + provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, + ``"DEEPSEEK"``). Case-insensitive. + provider_prefix: The LiteLLM model-string prefix the same call + site builds for the model id (e.g. ``"openrouter"``, + ``"groq"``). Case-insensitive. + config_api_base: ``api_base`` from the global YAML / DB row / + OpenRouter dynamic config. Empty / whitespace-only means + "missing" — the resolver still applies the cascade. + + Returns: + A URL string, or ``None`` if no default applies for this provider. + """ + if config_api_base and config_api_base.strip(): + return config_api_base + + if provider: + key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) + if key_default: + return key_default + + if provider_prefix: + prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) + if prefix_default: + return prefix_default + + return None + + +__all__ = [ + "PROVIDER_DEFAULT_API_BASE", + "PROVIDER_KEY_DEFAULT_API_BASE", + "resolve_api_base", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index fae283ab6..f094c9954 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -49,6 +49,51 @@ import litellm logger = logging.getLogger(__name__) +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. 
``app.agents.chat.runtime.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> deliverables/tools/generate_image -> +# app.config`` cycle that prompted the move. +_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + def _candidate_model_strings( *, provider: str | None, @@ -78,7 +123,12 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix = custom_provider or provider + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. + provider_prefix = custom_provider primary_model = base_model or model_name bare_model = model_name diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py deleted file mode 100644 index 67d1c4db4..000000000 --- a/surfsense_backend/app/services/provider_registry.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Provider registry for model connections. 
- -The provider string is the single public identity axis. This registry only -describes providers whose behavior differs from LiteLLM's native default. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum -from typing import Literal - - -class Transport(StrEnum): - NATIVE = "NATIVE" - OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - OLLAMA = "OLLAMA" - - -DiscoveryKind = Literal[ - "ollama", - "openai_models", - "anthropic_models", - "bedrock_models", - "openrouter", - "static", - "none", -] - -AuthStyle = Literal["bearer", "x-api-key", "none", "native"] - - -@dataclass(frozen=True) -class ProviderSpec: - transport: Transport - litellm_prefix: str | None - discovery: DiscoveryKind - default_base_url: str | None - base_url_required: bool - auth_style: AuthStyle - display_name: str | None = None - - -REGISTRY: dict[str, ProviderSpec] = { - "openai": ProviderSpec( - Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI" - ), - "anthropic": ProviderSpec( - Transport.NATIVE, - "anthropic", - "anthropic_models", - None, - False, - "x-api-key", - "Anthropic", - ), - "azure": ProviderSpec( - Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI" - ), - "vertex_ai": ProviderSpec( - Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI" - ), - "bedrock": ProviderSpec( - Transport.NATIVE, - "bedrock", - "bedrock_models", - None, - False, - "native", - "Amazon Bedrock", - ), - "openrouter": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openrouter", - "openrouter", - "https://openrouter.ai/api/v1", - False, - "bearer", - "OpenRouter", - ), - "openai_compatible": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openai", - "openai_models", - None, - True, - "bearer", - "OpenAI-compatible provider", - ), - "lm_studio": ProviderSpec( - Transport.OPENAI_COMPATIBLE, - "openai", - "openai_models", - "http://host.docker.internal:1234/v1", - True, - "bearer", - "LM Studio", - ), 
- "ollama_chat": ProviderSpec( - Transport.OLLAMA, - "ollama_chat", - "ollama", - "http://host.docker.internal:11434", - True, - "none", - "Ollama", - ), -} - - -def spec_for(provider: str | None) -> ProviderSpec: - provider_key = (provider or "").strip() - return REGISTRY.get(provider_key) or ProviderSpec( - Transport.NATIVE, provider_key or "openai", "static", None, False, "native" - ) - - -def provider_label(provider: str | None) -> str: - provider_key = (provider or "").strip() - spec = spec_for(provider_key) - if spec.display_name: - return spec.display_name - return provider_key.replace("_", " ").title() if provider_key else "Provider" - - -__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"] diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 737dd7c2f..2fb37de21 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -1,4 +1,4 @@ -"""Pure-function quality scoring for Auto model selection. +"""Pure-function quality scoring for Auto (Fastest) model selection. This module is import-free of any service / request-path dependencies. All numbers are computed once during the OpenRouter refresh tick (or YAML load) @@ -108,23 +108,25 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = { # YAML provider field (the upstream API shape the operator selected). 
PROVIDER_PRESTIGE_YAML: dict[str, int] = { - "azure": 50, - "openai": 50, - "anthropic": 50, - "gemini": 50, - "vertex_ai": 50, - "xai": 50, - "mistral": 38, - "deepseek": 38, - "cohere": 38, - "groq": 30, - "together_ai": 28, - "fireworks_ai": 28, - "perplexity": 28, - "bedrock": 28, - "openrouter": 25, - "ollama_chat": 12, - "custom": 12, + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, } @@ -273,7 +275,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider", "")).upper() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index 0cb6cd092..6db5e2604 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -238,14 +238,9 @@ async def _restore_in_place_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk( - document_id=doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=True) + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip( + chunk_texts, chunk_embeddings, strict=True ) ] ) @@ -341,15 +336,8 @@ async def _reinsert_document_from_revision( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk( - 
document_id=new_doc.id, - content=text, - embedding=embedding, - position=i, - ) - for i, (text, embedding) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=True) - ) + Chunk(document_id=new_doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) ] ) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index d1a29b82a..3f07e6f9e 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -32,23 +32,6 @@ from app.db import TokenUsage logger = logging.getLogger(__name__) -def _bare_model_name(model: str) -> str: - """Return a model identifier with any provider routing prefix stripped. - - LiteLLM's ``get_llm_provider`` consumes the provider prefix we add in - ``to_litellm`` (e.g. ``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because - ``azure`` is in ``litellm.provider_list``). The token-tracking success - callback therefore reports ``kwargs["model"]`` *without* that prefix, - while model metadata is registered under the *prefixed* string. Normalising - both sides to the last path segment lets the two reconcile so the per-model - breakdown carries provider/display_name and the UI attributes the turn to - the correct connection instead of falling back to a bare-name collision. 
- """ - if not model: - return model - return model.split("/")[-1] - - @dataclass class TokenCallRecord: model: str @@ -57,10 +40,6 @@ class TokenCallRecord: total_tokens: int cost_micros: int = 0 call_kind: str = "chat" - model_ref: str | None = None - model_id: str | None = None - display_name: str | None = None - provider: str | None = None @dataclass @@ -68,46 +47,6 @@ class TurnTokenAccumulator: """Accumulates token usage across all LLM calls within a single user turn.""" calls: list[TokenCallRecord] = field(default_factory=list) - model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict) - # Secondary index keyed by the bare model name (provider prefix stripped) so - # the LiteLLM callback — which never sees our routing prefix — can still - # reconcile its ``kwargs["model"]`` back to the registered metadata. - model_metadata_by_bare: dict[str, dict[str, str | None]] = field( - default_factory=dict - ) - - def register_model_metadata( - self, - *, - model: str, - model_ref: str | None, - model_id: str | None, - display_name: str | None, - provider: str | None, - ) -> None: - """Attach resolved model metadata for later LiteLLM callback attribution.""" - metadata = { - "model_ref": model_ref, - "model_id": model_id, - "display_name": display_name, - "provider": provider, - } - self.model_metadata[model] = metadata - # Index every reconcilable alias: the prefixed string's bare form and - # the resolved ``model_id`` (which for some providers is itself the bare - # deployment LiteLLM reports). Exact lookups always take precedence. 
- self.model_metadata_by_bare[_bare_model_name(model)] = metadata - if model_id: - self.model_metadata_by_bare.setdefault(_bare_model_name(model_id), metadata) - - def _lookup_metadata(self, model: str) -> dict[str, str | None]: - """Resolve registered metadata for a callback model, tolerating the - provider-prefix stripping LiteLLM applies before the success callback - fires (see :func:`_bare_model_name`).""" - exact = self.model_metadata.get(model) - if exact is not None: - return exact - return self.model_metadata_by_bare.get(_bare_model_name(model), {}) def add( self, @@ -118,14 +57,9 @@ class TurnTokenAccumulator: cost_micros: int = 0, call_kind: str = "chat", ) -> None: - metadata = self._lookup_metadata(model) self.calls.append( TokenCallRecord( model=model, - model_ref=metadata.get("model_ref"), - model_id=metadata.get("model_id"), - display_name=metadata.get("display_name"), - provider=metadata.get("provider"), prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, @@ -134,18 +68,13 @@ class TurnTokenAccumulator: ) ) - def per_message_summary(self) -> dict[str, dict[str, Any]]: + def per_message_summary(self) -> dict[str, dict[str, int]]: """Return token counts (and cost) grouped by model name.""" - by_model: dict[str, dict[str, Any]] = {} + by_model: dict[str, dict[str, int]] = {} for c in self.calls: entry = by_model.setdefault( c.model, { - "model": c.model, - "model_ref": c.model_ref, - "model_id": c.model_id, - "display_name": c.display_name, - "provider": c.provider, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, @@ -213,27 +142,6 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() -def register_model_usage_metadata( - *, - model: str, - model_ref: str | None, - model_id: str | None, - display_name: str | None, - provider: str | None, -) -> None: - """Register resolved model metadata with the current turn, if one exists.""" - acc = 
_turn_accumulator.get() - if acc is None: - return - acc.register_model_metadata( - model=model, - model_ref=model_ref, - model_id=model_id, - display_name=display_name, - provider=provider, - ) - - @asynccontextmanager async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: """Async context manager that scopes a fresh ``TurnTokenAccumulator`` diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py new file mode 100644 index 000000000..ed5de921c --- /dev/null +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -0,0 +1,201 @@ +import logging +from typing import Any + +from litellm import Router + +from app.services.provider_api_base import resolve_api_base + +logger = logging.getLogger(__name__) + +VISION_AUTO_MODE_ID = 0 + +VISION_PROVIDER_MAP = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GOOGLE": "gemini", + "AZURE_OPENAI": "azure", + "VERTEX_AI": "vertex_ai", + "BEDROCK": "bedrock", + "XAI": "xai", + "OPENROUTER": "openrouter", + "OLLAMA": "ollama_chat", + "GROQ": "groq", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "MISTRAL": "mistral", + "CUSTOM": "custom", +} + + +class VisionLLMRouterService: + _instance = None + _router: Router | None = None + _model_list: list[dict] = [] + _router_settings: dict = {} + _initialized: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls) -> "VisionLLMRouterService": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + instance = cls.get_instance() + + if instance._initialized: + logger.debug("Vision LLM Router already initialized, skipping") + return + + model_list = [] + for config in global_configs: + 
deployment = cls._config_to_deployment(config) + if deployment: + model_list.append(deployment) + + if not model_list: + logger.warning( + "No valid vision LLM configs found for router initialization" + ) + return + + instance._model_list = model_list + instance._router_settings = router_settings or {} + + default_settings = { + "routing_strategy": "usage-based-routing", + "num_retries": 3, + "allowed_fails": 3, + "cooldown_time": 60, + "retry_after": 5, + } + + final_settings = {**default_settings, **instance._router_settings} + + try: + instance._router = Router( + model_list=model_list, + routing_strategy=final_settings.get( + "routing_strategy", "usage-based-routing" + ), + num_retries=final_settings.get("num_retries", 3), + allowed_fails=final_settings.get("allowed_fails", 3), + cooldown_time=final_settings.get("cooldown_time", 60), + set_verbose=False, + ) + instance._initialized = True + logger.info( + "Vision LLM Router initialized with %d deployments, strategy: %s", + len(model_list), + final_settings.get("routing_strategy"), + ) + except Exception as e: + logger.error(f"Failed to initialize Vision LLM Router: {e}") + instance._router = None + + @classmethod + def _config_to_deployment(cls, config: dict) -> dict | None: + try: + if not config.get("model_name") or not config.get("api_key"): + return None + + provider = config.get("provider", "").upper() + if config.get("custom_provider"): + provider_prefix = config["custom_provider"] + model_string = f"{provider_prefix}/{config['model_name']}" + else: + provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) + model_string = f"{provider_prefix}/{config['model_name']}" + + litellm_params: dict[str, Any] = { + "model": model_string, + "api_key": config.get("api_key"), + } + + api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base + + if config.get("api_version"): + 
litellm_params["api_version"] = config["api_version"] + + if config.get("litellm_params"): + litellm_params.update(config["litellm_params"]) + + deployment: dict[str, Any] = { + "model_name": "auto", + "litellm_params": litellm_params, + } + + if config.get("rpm"): + deployment["rpm"] = config["rpm"] + if config.get("tpm"): + deployment["tpm"] = config["tpm"] + + return deployment + + except Exception as e: + logger.warning(f"Failed to convert vision config to deployment: {e}") + return None + + @classmethod + def get_router(cls) -> Router | None: + instance = cls.get_instance() + return instance._router + + @classmethod + def is_initialized(cls) -> bool: + instance = cls.get_instance() + return instance._initialized and instance._router is not None + + @classmethod + def get_model_count(cls) -> int: + instance = cls.get_instance() + return len(instance._model_list) + + +def is_vision_auto_mode(config_id: int | None) -> bool: + return config_id == VISION_AUTO_MODE_ID + + +def build_vision_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: + if custom_provider: + return f"{custom_provider}/{model_name}" + prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) + return f"{prefix}/{model_name}" + + +def get_global_vision_llm_config(config_id: int) -> dict | None: + from app.config import config + + if config_id == VISION_AUTO_MODE_ID: + return { + "id": VISION_AUTO_MODE_ID, + "name": "Auto (Fastest)", + "provider": "AUTO", + "model_name": "auto", + "is_auto_mode": True, + } + if config_id > 0: + return None + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if cfg.get("id") == config_id: + return cfg + return None diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py new file mode 100644 index 000000000..fc459910b --- /dev/null +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -0,0 +1,134 @@ +""" +Service for fetching and 
caching the vision-capable model list. + +Reuses the same OpenRouter public API and local fallback as the LLM model +list service, but filters for models that accept image input. +""" + +import json +import logging +import time +from pathlib import Path + +import httpx + +logger = logging.getLogger(__name__) + +OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +FALLBACK_FILE = ( + Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" +) +CACHE_TTL_SECONDS = 86400 # 24 hours + +_cache: list[dict] | None = None +_cache_timestamp: float = 0 + +OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { + "openai": "OPENAI", + "anthropic": "ANTHROPIC", + "google": "GOOGLE", + "mistralai": "MISTRAL", + "x-ai": "XAI", +} + + +def _format_context_length(length: int | None) -> str | None: + if not length: + return None + if length >= 1_000_000: + return f"{length / 1_000_000:g}M" + if length >= 1_000: + return f"{length / 1_000:g}K" + return str(length) + + +async def _fetch_from_openrouter() -> list[dict] | None: + try: + async with httpx.AsyncClient(timeout=15) as client: + response = await client.get(OPENROUTER_API_URL) + response.raise_for_status() + data = response.json() + return data.get("data", []) + except Exception as e: + logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) + return None + + +def _load_fallback() -> list[dict]: + try: + with open(FALLBACK_FILE, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error("Failed to load vision model fallback list: %s", e) + return [] + + +def _is_vision_model(model: dict) -> bool: + """Return True if the model accepts image input and outputs text.""" + arch = model.get("architecture", {}) + input_mods = arch.get("input_modalities", []) + output_mods = arch.get("output_modalities", []) + return "image" in input_mods and "text" in output_mods + + +def _process_vision_models(raw_models: list[dict]) -> list[dict]: + processed: list[dict] = 
[] + + for model in raw_models: + model_id: str = model.get("id", "") + name: str = model.get("name", "") + context_length = model.get("context_length") + + if "/" not in model_id: + continue + + if not _is_vision_model(model): + continue + + provider_slug, model_name = model_id.split("/", 1) + context_window = _format_context_length(context_length) + + processed.append( + { + "value": model_id, + "label": name, + "provider": "OPENROUTER", + "context_window": context_window, + } + ) + + native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if native_provider: + if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + continue + + processed.append( + { + "value": model_name, + "label": name, + "provider": native_provider, + "context_window": context_window, + } + ) + + return processed + + +async def get_vision_model_list() -> list[dict]: + global _cache, _cache_timestamp + + if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: + return _cache + + raw_models = await _fetch_from_openrouter() + + if raw_models is None: + logger.info("Using fallback vision model list") + return _load_fallback() + + processed = _process_vision_models(raw_models) + + _cache = processed + _cache_timestamp = time.time() + + return processed diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index a1113884f..6ea7a2e68 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -32,27 +32,10 @@ def get_celery_session_maker() -> async_sessionmaker: """ global _celery_engine, _celery_session_maker if _celery_session_maker is None: - # Reap connections orphaned mid-transaction (e.g. a worker that hung or - # crashed mid-index) so they can't hold locks on documents/chunks and - # wedge writes — the failure mode that previously left an "idle in - # transaction" session holding locks for 11+ hours. 
Kept generous so a - # legitimate long per-document embed window is never killed. - connect_args: dict = {} - idle_ms = config.DB_CELERY_IDLE_IN_TX_TIMEOUT_MS - if ( - idle_ms - and idle_ms > 0 - and config.DATABASE_URL - and "asyncpg" in config.DATABASE_URL - ): - connect_args["server_settings"] = { - "idle_in_transaction_session_timeout": str(idle_ms) - } _celery_engine = create_async_engine( config.DATABASE_URL, poolclass=NullPool, echo=False, - connect_args=connect_args, ) with contextlib.suppress(Exception): from app.observability.bootstrap import instrument_sqlalchemy_engine diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 4d71d6c9a..41e029a60 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -602,29 +602,23 @@ async def _process_file_upload( # Create notification for document processing logger.info(f"[_process_file_upload] Creating notification for: {filename}") - notification = None - heartbeat_task = None - try: - notification = ( - await NotificationService.document_processing.notify_processing_started( - session=session, - user_id=UUID(user_id), - document_type="FILE", - document_name=filename, - search_space_id=search_space_id, - file_size=file_size, - ) - ) - logger.info( - f"[_process_file_upload] Notification created with ID: {notification.id}" - ) - _start_heartbeat(notification.id) - heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) - except Exception: - logger.warning( - f"[_process_file_upload] Failed to create notification for: {filename}", - exc_info=True, + notification = ( + await NotificationService.document_processing.notify_processing_started( + session=session, + user_id=UUID(user_id), + document_type="FILE", + document_name=filename, + search_space_id=search_space_id, + file_size=file_size, ) + ) + logger.info( + 
f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}" + ) + + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) log_entry = await task_logger.log_task_start( task_name="process_file_upload", @@ -652,21 +646,23 @@ async def _process_file_upload( # Update notification on success if result: - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, document_id=result.id, chunks_count=None, ) + ) else: # Duplicate detected - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, error_message="Document already exists (duplicate)", ) + ) except Exception as e: # Import here to avoid circular dependencies @@ -695,13 +691,13 @@ async def _process_file_upload( error_message = str(credit_error) # Create a dedicated insufficient credits notification try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message="Insufficient credits", - ) + # First, mark the processing notification as failed + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message="Insufficient credits", + ) # Then create a separate insufficient_credits notification for better UX await NotificationService.insufficient_credits.notify_insufficient_credits( @@ -721,13 +717,12 @@ async def _process_file_upload( # HTTPException with page limit message but no detailed cause 
error_message = str(e.detail) try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=error_message, - ) + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=error_message, + ) except Exception as notif_error: logger.error( f"Failed to update notification on failure: {notif_error!s}" @@ -736,13 +731,13 @@ async def _process_file_upload( error_message = str(e)[:100] # Update notification on failure - wrapped in try-except to ensure it doesn't fail silently try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=error_message, - ) + # Refresh notification to ensure it's not stale after any rollback + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=error_message, + ) except Exception as notif_error: logger.error( f"Failed to update notification on failure: {notif_error!s}" @@ -758,10 +753,8 @@ async def _process_file_upload( raise finally: # Stop heartbeat — key deleted on success, expires on crash - if heartbeat_task: - heartbeat_task.cancel() - if notification: - _stop_heartbeat(notification.id) + heartbeat_task.cancel() + _stop_heartbeat(notification.id) @celery_app.task(name="process_file_upload_with_document", bind=True) @@ -901,36 +894,29 @@ async def _process_file_with_document( logger.info( f"[_process_file_with_document] Creating notification for: {filename}" ) - notification = None - heartbeat_task = None - try: - notification = ( - await NotificationService.document_processing.notify_processing_started( - 
session=session, - user_id=UUID(user_id), - document_type="FILE", - document_name=filename, - search_space_id=search_space_id, - file_size=file_size, - ) + notification = ( + await NotificationService.document_processing.notify_processing_started( + session=session, + user_id=UUID(user_id), + document_type="FILE", + document_name=filename, + search_space_id=search_space_id, + file_size=file_size, ) + ) - # Store document_id in notification metadata so cleanup task can find the document - if notification.notification_metadata is not None: - notification.notification_metadata["document_id"] = document_id - from sqlalchemy.orm.attributes import flag_modified + # Store document_id in notification metadata so cleanup task can find the document + if notification and notification.notification_metadata is not None: + notification.notification_metadata["document_id"] = document_id + from sqlalchemy.orm.attributes import flag_modified - flag_modified(notification, "notification_metadata") - await session.commit() - await session.refresh(notification) + flag_modified(notification, "notification_metadata") + await session.commit() + await session.refresh(notification) - _start_heartbeat(notification.id) - heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) - except Exception: - logger.warning( - f"[_process_file_with_document] Failed to create notification for: {filename}", - exc_info=True, - ) + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) log_entry = await task_logger.log_task_start( task_name="process_file_upload_with_document", @@ -970,13 +956,14 @@ async def _process_file_with_document( # Update notification on success if result: - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, 
notification=notification, document_id=result.id, chunks_count=None, ) + ) logger.info( f"[_process_file_with_document] Successfully processed document {document_id}" ) @@ -985,12 +972,13 @@ async def _process_file_with_document( document.status = DocumentStatus.failed("Duplicate content detected") document.updated_at = get_current_timestamp() await session.commit() - if notification: - await NotificationService.document_processing.notify_processing_completed( + await ( + NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, error_message="Document already exists (duplicate)", ) + ) except Exception as e: # Import here to avoid circular dependencies @@ -1021,13 +1009,12 @@ async def _process_file_with_document( # Handle insufficient-credit errors with dedicated notification if credit_error is not None: try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message="Insufficient credits", - ) + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message="Insufficient credits", + ) await NotificationService.insufficient_credits.notify_insufficient_credits( session=session, user_id=UUID(user_id), @@ -1044,13 +1031,12 @@ async def _process_file_with_document( else: # Update notification on failure try: - if notification: - await session.refresh(notification) - await NotificationService.document_processing.notify_processing_completed( - session=session, - notification=notification, - error_message=str(e)[:100], - ) + await session.refresh(notification) + await NotificationService.document_processing.notify_processing_completed( + session=session, + notification=notification, + error_message=str(e)[:100], + ) except Exception as notif_error: logger.error( 
f"Failed to update notification on failure: {notif_error!s}" @@ -1067,10 +1053,8 @@ async def _process_file_with_document( finally: # Stop heartbeat — key deleted on success, expires on crash - if heartbeat_task: - heartbeat_task.cancel() - if notification: - _stop_heartbeat(notification.id) + heartbeat_task.cancel() + _stop_heartbeat(notification.id) # Clean up temp file if os.path.exists(temp_path): diff --git a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py deleted file mode 100644 index 3394913c3..000000000 --- a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Convert persisted chat content into provider-safe LangChain history. - -Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module -extracts only model-safe content before prior turns are replayed to a provider. -""" - -from __future__ import annotations - -from typing import Any - -_USER_CONTENT_TYPES = {"text", "image", "image_url"} - - -def _text_from_block(block: dict[str, Any]) -> str: - value = block.get("text") or block.get("content") or "" - return value if isinstance(value, str) else "" - - -def assistant_content_to_llm_text(content: Any) -> str: - """Return visible assistant text, dropping reasoning/UI/provider blocks.""" - if isinstance(content, str): - return content - if isinstance(content, dict): - return _text_from_block(content) - if not isinstance(content, list): - return "" - - text_chunks: list[str] = [] - for block in content: - if isinstance(block, str): - if block: - text_chunks.append(block) - continue - if not isinstance(block, dict): - continue - if block.get("type") == "text": - text = _text_from_block(block) - if text: - text_chunks.append(text) - return "\n".join(text_chunks) - - -def user_content_to_llm_content( - content: Any, - *, - allow_images: bool = True, -) -> str | list[dict[str, Any]]: - """Return provider-safe user text/image 
content for LangChain.""" - if isinstance(content, str): - return content - if isinstance(content, dict): - return _text_from_block(content) - if not isinstance(content, list): - return "" - - parts: list[dict[str, Any]] = [] - text_chunks: list[str] = [] - for block in content: - if isinstance(block, str): - if block: - text_chunks.append(block) - continue - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type not in _USER_CONTENT_TYPES: - continue - if block_type == "text": - text = _text_from_block(block) - if text: - parts.append({"type": "text", "text": text}) - text_chunks.append(text) - elif allow_images and block_type == "image": - image = block.get("image") - if isinstance(image, str) and image.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": image}}) - elif allow_images and block_type == "image_url": - image_url = block.get("image_url") - if isinstance(image_url, dict): - url = image_url.get("url") - if isinstance(url, str) and url.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": url}}) - elif isinstance(image_url, str) and image_url.startswith("data:"): - parts.append({"type": "image_url", "image_url": {"url": image_url}}) - - if allow_images and any(part.get("type") == "image_url" for part in parts): - return parts - return "\n".join(text_chunks) diff --git a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py deleted file mode 100644 index a4b636538..000000000 --- a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Normalize final LangChain assistant messages into assistant-ui parts. - -Live streaming remains the primary source for rich, incremental UI state. -This module is only used after the graph has finished so refresh persistence -does not depend on provider-specific streaming chunk shapes. 
-""" - -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any - -from langchain_core.messages import AIMessage - - -def _text_from_content(content: Any) -> str: - if isinstance(content, str): - return content - if not isinstance(content, list): - return "" - - text_parts: list[str] = [] - for block in content: - if not isinstance(block, dict): - continue - if block.get("type") != "text": - continue - value = block.get("text") or block.get("content") or "" - if isinstance(value, str) and value: - text_parts.append(value) - return "".join(text_parts) - - -def normalize_ai_message_to_parts( - message: AIMessage | Any | None, -) -> list[dict[str, Any]]: - """Return user-visible assistant-ui parts for a final AI message. - - We intentionally do not backfill provider ``thinking`` / - ``reasoning_content`` blocks here. If reasoning streamed live, the - ``AssistantContentBuilder`` already captured it. If it only exists in the - final model payload, persisting it retroactively could expose content the - UI never showed during the turn. 
- """ - if message is None: - return [] - - text = _text_from_content(getattr(message, "content", None)).strip() - if not text: - return [] - return [{"type": "text", "text": text}] - - -def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None: - if messages is None: - return None - for message in reversed(list(messages)): - if isinstance(message, AIMessage): - return message - if getattr(message, "type", None) == "ai": - return message - return None - - -def final_assistant_parts_from_messages( - messages: Iterable[Any] | None, -) -> list[dict[str, Any]]: - return normalize_ai_message_to_parts(last_ai_message(messages)) - - -def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool: - return any( - part.get("type") == "text" - and isinstance(part.get("text"), str) - and bool(part.get("text", "").strip()) - for part in parts - ) - - -def merge_streamed_and_final_parts( - streamed_parts: list[dict[str, Any]], - final_parts: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Use final-state text only when streaming captured no answer text.""" - if has_non_empty_text_part(streamed_parts): - return streamed_parts - if not has_non_empty_text_part(final_parts): - return streamed_parts - return [*streamed_parts, *final_parts] diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py index 939cd9b17..d96144bcd 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py @@ -16,9 +16,6 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.kb_persistence impor ) from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.message_parts_normalizer import ( - final_assistant_parts_from_messages, -) from app.tasks.chat.streaming.contract.file_contract 
import ( contract_enforcement_active, evaluate_file_contract_outcome, @@ -78,9 +75,6 @@ async def stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} - result.final_message_parts = final_assistant_parts_from_messages( - state_values.get("messages") - ) # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py index 3fc5918ee..6b37df343 100644 --- a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -12,7 +12,6 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( is_cancel_requested, ) from app.agents.chat.runtime.errors import BusyError -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 @@ -103,9 +102,6 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: def is_provider_rate_limited(exc: BaseException) -> bool: """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" - if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED: - return True - raw = str(exc) lowered = raw.lower() if "ratelimit" in type(exc).__name__.lower(): @@ -135,85 +131,6 @@ def is_provider_rate_limited(exc: BaseException) -> bool: ) -def _provider_error_extra(adapted: Any) -> dict[str, Any] | None: - extra: dict[str, Any] = {"provider_error_category": adapted.category.value} - if adapted.provider_status_code is not None: - extra["provider_status_code"] = adapted.provider_status_code - if adapted.provider_error_type: - extra["provider_error_type"] = adapted.provider_error_type - return extra - - -def _classify_provider_exception( - exc: Exception, -) 
-> ( - tuple[str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None] - | None -): - adapted = adapt_llm_exception(exc) - - if adapted.category is LLMErrorCategory.RATE_LIMITED: - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", - _provider_error_extra(adapted), - ) - - if adapted.category in { - LLMErrorCategory.AUTH_FAILED, - LLMErrorCategory.PERMISSION_DENIED, - }: - return ( - "model_auth_failed", - "MODEL_AUTH_FAILED", - "warn", - True, - "This model's API key is invalid or expired. Switch models, or update the API key.", - _provider_error_extra(adapted), - ) - - if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND: - return ( - "model_not_found", - "MODEL_NOT_FOUND", - "warn", - True, - "The selected model is unavailable or no longer exists. Switch to another model and try again.", - _provider_error_extra(adapted), - ) - - if adapted.category is LLMErrorCategory.CONTEXT_LIMIT: - return ( - "model_context_limit", - "MODEL_CONTEXT_LIMIT", - "warn", - True, - "This request is too large for the selected model. Try a model with a larger context window or reduce the input.", - _provider_error_extra(adapted), - ) - - if adapted.category in { - LLMErrorCategory.TIMEOUT, - LLMErrorCategory.PROVIDER_UNAVAILABLE, - LLMErrorCategory.BAD_GATEWAY, - LLMErrorCategory.CONNECTION_FAILED, - LLMErrorCategory.SERVER_ERROR, - }: - return ( - "model_provider_unavailable", - "MODEL_PROVIDER_UNAVAILABLE", - "warn", - True, - "The selected model provider is temporarily unavailable. 
Please try again or switch models.", - _provider_error_extra(adapted), - ) - - return None - - def classify_stream_exception( exc: Exception, *, @@ -250,9 +167,15 @@ def classify_stream_exception( None, ) - provider_classification = _classify_provider_exception(exc) - if provider_classification is not None: - return provider_classification + if is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, + ) return ( "server_error", diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index d5e8c3729..fe3d210bb 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -80,6 +80,7 @@ async def _generate_title( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator # Excludes this turn's own assistant row (pre-written by @@ -124,12 +125,26 @@ async def _generate_title( router = LLMRouterService.get_router() response = await router.acompletion(model="auto", messages=messages) else: + # Apply the same ``api_base`` cascade chat / vision / image-gen + # call sites use so we never inherit ``litellm.api_base`` + # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat + # config itself ships an empty ``api_base``. Without this the + # title-gen on an OpenRouter chat config would 404 against the + # inherited Azure endpoint — see ``provider_api_base`` for the + # same bug repro on the image-gen / vision paths. 
raw_model = getattr(llm, "model", "") or "" + provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None + provider_value = agent_config.provider if agent_config is not None else None + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py index 3f767c60b..be1f102f3 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -53,7 +53,6 @@ async def finalize_assistant_message( ): return - from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts from app.tasks.chat.persistence import finalize_assistant_turn builder_stats: dict[str, int] | None = None @@ -75,10 +74,6 @@ async def finalize_assistant_message( "text": stream_result.accumulated_text or "", } ] - content_payload = merge_streamed_and_final_parts( - content_payload, - stream_result.final_message_parts, - ) if builder_stats is not None: _perf_log.info( diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 6f905e8f4..7e2bc950b 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -1,8 +1,8 @@ """Load an LLM + AgentConfig bundle for a given config id. Handles both code paths uniformly: -- ``config_id > 0`` → database-backed model-connection ``Model`` row. 
-- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter. +- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). +- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is ``None``. The caller emits the friendly SSE error frame. @@ -12,78 +12,15 @@ from __future__ import annotations from typing import Any -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.agents.chat.runtime.llm_config import ( AgentConfig, - SanitizedChatLiteLLM, + create_chat_litellm_from_agent_config, + create_chat_litellm_from_config, + load_agent_config, + load_global_llm_config_by_id, ) -from app.config import config -from app.db import Model, SearchSpace -from app.services.model_capabilities import has_capability -from app.services.model_resolver import to_litellm -from app.services.token_tracking_service import register_model_usage_metadata - - -def _agent_config_from_resolved( - *, - config_id: int, - config_name: str | None, - provider: str, - model_name: str, - api_key: str | None, - api_base: str | None, - litellm_params: dict | None, - supports_image_input: bool, - billing_tier: str = "free", -) -> AgentConfig: - return AgentConfig( - provider=provider, - model_name=model_name, - api_key=api_key or "", - api_base=api_base, - custom_provider=None, - litellm_params=litellm_params, - config_id=config_id, - config_name=config_name, - is_auto_mode=False, - billing_tier=billing_tier, - is_premium=billing_tier == "premium", - supports_image_input=supports_image_input, - ) - - -async def _load_search_space( - session: AsyncSession, search_space_id: int -) -> SearchSpace | None: - result = await session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - return result.scalars().first() - - -async def _load_db_model( - session: AsyncSession, - *, - 
model_id: int, - search_space: SearchSpace, -) -> Model | None: - result = await session.execute( - select(Model) - .options(selectinload(Model.connection)) - .where(Model.id == model_id, Model.enabled.is_(True)) - ) - model = result.scalars().first() - if not model or not model.connection or not model.connection.enabled: - return None - conn = model.connection - if conn.search_space_id is not None and conn.search_space_id != search_space.id: - return None - if conn.user_id is not None and conn.user_id != search_space.user_id: - return None - return model async def load_llm_bundle( @@ -92,93 +29,29 @@ async def load_llm_bundle( config_id: int, search_space_id: int, ) -> tuple[Any, AgentConfig | None, str | None]: - search_space = await _load_search_space(session, search_space_id) - if not search_space: - return None, None, f"Search space {search_space_id} not found" - - if config_id > 0: - model = await _load_db_model( - session, - model_id=config_id, - search_space=search_space, + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, ) - if not model or not has_capability(model, "chat"): + if not loaded_agent_config: return ( None, None, - f"Failed to load chat model with id {config_id}", + f"Failed to load NewLLMConfig with id {config_id}", ) - model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) - display_name = model.display_name or model.model_id - provider = model.connection.provider or "" - register_model_usage_metadata( - model=model_string, - model_ref=f"db:{model.id}", - model_id=model.model_id, - display_name=display_name, - provider=provider, - ) - agent_config = _agent_config_from_resolved( - config_id=config_id, - config_name=display_name, - provider=provider, - model_name=model.model_id, - api_key=model.connection.api_key, - api_base=model.connection.base_url, - litellm_params=(model.connection.extra or {}).get("litellm_params"), - 
supports_image_input=has_capability(model, "vision"), - billing_tier="free", - ) return ( - SanitizedChatLiteLLM( - model=model_string, **{**litellm_kwargs, "streaming": True} - ), - agent_config, + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, None, ) - global_model = next( - (m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None - ) - if not global_model or not has_capability(global_model, "chat"): - return None, None, f"Failed to load global chat model with id {config_id}" - global_connection = next( - ( - c - for c in config.GLOBAL_CONNECTIONS - if c.get("id") == global_model.get("connection_id") - ), - None, - ) - if not global_connection: - return None, None, f"Failed to load global connection for model {config_id}" - model_string, litellm_kwargs = to_litellm( - global_connection, global_model["model_id"] - ) - display_name = global_model.get("display_name") or global_model.get("model_id") - provider = global_connection.get("provider") or "" - register_model_usage_metadata( - model=model_string, - model_ref=f"global:{config_id}", - model_id=global_model["model_id"], - display_name=display_name, - provider=provider, - ) - agent_config = _agent_config_from_resolved( - config_id=config_id, - config_name=display_name, - provider=provider, - model_name=global_model["model_id"], - api_key=global_connection.get("api_key"), - api_base=global_connection.get("base_url"), - litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), - supports_image_input=has_capability(global_model, "vision"), - billing_tier=str(global_model.get("billing_tier", "free")).lower(), - ) + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" return ( - SanitizedChatLiteLLM( - model=model_string, **{**litellm_kwargs, "streaming": True} - ), - agent_config, + create_chat_litellm_from_config(loaded_llm_config), + 
AgentConfig.from_yaml_config(loaded_llm_config), None, ) diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py index 5e164070a..a940e8a9f 100644 --- a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -35,7 +35,3 @@ class StreamResult: # (``StreamResult`` is logged in some error branches) from dumping a # potentially-large parts list. content_builder: Any | None = field(default=None, repr=False) - # User-visible assistant message parts derived from the final LangGraph - # state. Used after streaming completes as a provider-agnostic persistence - # backfill when no text chunks reached the live stream. - final_message_parts: list[dict[str, Any]] = field(default_factory=list) diff --git a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py index 557c2ce71..ce9b80e5e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py @@ -525,7 +525,6 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list: Chunk( content=chunk_text, embedding=embed_text(chunk_text), - position=len(chunks), ) ) diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 2505fa7c4..1a2d4b967 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -162,13 +162,12 @@ async def _read_file_content( All file types (plaintext, audio, direct-convert, document, image) are handled by ``EtlPipelineService``. 
""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService mode = ProcessingMode.coerce(processing_mode) - result = await extract_with_cache( - EtlRequest(file_path=file_path, filename=filename, processing_mode=mode), - vision_llm=vision_llm, + result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=file_path, filename=filename, processing_mode=mode) ) return result.markdown_content diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 174ac966d..a646b7aa6 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -1,9 +1,8 @@ """ File document processors orchestrating content extraction and indexing. -Delegates content extraction to the cache-aware ``extract_with_cache`` facade -(over ``EtlPipelineService``) and keeps only orchestration concerns -(notifications, logging, page limits, saving). +Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and +keeps only orchestration concerns (notifications, logging, page limits, saving). 
""" from __future__ import annotations @@ -117,8 +116,8 @@ async def _log_page_divergence( async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None: """Extract content from a non-document file (plaintext/direct_convert/audio/image) via the unified ETL pipeline.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService await _notify(ctx, "parsing", "Processing file") await ctx.task_logger.log_task_progress( @@ -137,9 +136,8 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await extract_with_cache( - EtlRequest(file_path=ctx.file_path, filename=ctx.filename), - vision_llm=vision_llm, + etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( + EtlRequest(file_path=ctx.file_path, filename=ctx.filename) ) with contextlib.suppress(Exception): @@ -185,8 +183,8 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: """Route a document file to the configured ETL service via the unified pipeline.""" - from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService from app.services.etl_credit_service import ( EtlCreditService, InsufficientCreditsError, @@ -239,14 +237,13 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await extract_with_cache( + etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( EtlRequest( file_path=ctx.file_path, filename=ctx.filename, estimated_pages=estimated_pages, processing_mode=mode, - ), - 
vision_llm=vision_llm, + ) ) with contextlib.suppress(Exception): @@ -384,6 +381,7 @@ async def _extract_file_content( Tuple of (markdown_content, etl_service_name, billable_pages). """ from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode + from app.etl_pipeline.etl_pipeline_service import EtlPipelineService from app.etl_pipeline.file_classifier import ( FileCategory, classify_file as etl_classify, @@ -434,16 +432,13 @@ async def _extract_file_content( vision_llm = await get_vision_llm(session, search_space_id) - from app.etl_pipeline.cache import extract_with_cache - - result = await extract_with_cache( + result = await EtlPipelineService(vision_llm=vision_llm).extract( EtlRequest( file_path=file_path, filename=filename, estimated_pages=estimated_pages, processing_mode=mode, - ), - vision_llm=vision_llm, + ) ) with contextlib.suppress(Exception): diff --git a/surfsense_backend/app/utils/content_utils.py b/surfsense_backend/app/utils/content_utils.py index aae936888..05a4610c7 100644 --- a/surfsense_backend/app/utils/content_utils.py +++ b/surfsense_backend/app/utils/content_utils.py @@ -18,11 +18,6 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.tasks.chat.llm_history_normalizer import ( - assistant_content_to_llm_text, - user_content_to_llm_content, -) - if TYPE_CHECKING: from app.db import ChatVisibility @@ -100,28 +95,17 @@ async def bootstrap_history_from_db( langchain_messages: list[HumanMessage | AIMessage] = [] for msg in db_messages: + text_content = extract_text_content(msg.content) + if not text_content: + continue if msg.role == "user": - user_content = user_content_to_llm_content( - msg.content, - allow_images=False, - ) - if not user_content: - continue if is_shared: author_name = ( msg.author.display_name if msg.author else None ) or "A team member" - if isinstance(user_content, str): - user_content = f"**[{author_name}]:** {user_content}" - elif 
user_content and user_content[0].get("type") == "text": - user_content[0] = { - **user_content[0], - "text": f"**[{author_name}]:** {user_content[0].get('text', '')}", - } - langchain_messages.append(HumanMessage(content=user_content)) + text_content = f"**[{author_name}]:** {text_content}" + langchain_messages.append(HumanMessage(content=text_content)) elif msg.role == "assistant": - assistant_text = assistant_content_to_llm_text(msg.content) - if assistant_text: - langchain_messages.append(AIMessage(content=assistant_text)) + langchain_messages.append(AIMessage(content=text_content)) return langchain_messages diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index fef51d692..694ae22ac 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -188,10 +188,8 @@ async def create_document_chunks(content: str) -> list[Chunk]: chunk_texts = [c.text for c in config.chunker_instance.chunk(content)] chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) return [ - Chunk(content=text, embedding=emb, position=i) - for i, (text, emb) in enumerate( - zip(chunk_texts, chunk_embeddings, strict=False) - ) + Chunk(content=text, embedding=emb) + for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) ] diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 6afc7fd15..ff43f6a97 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.29" +version = "0.0.28" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ @@ -41,6 +41,7 @@ dependencies = [ "elasticsearch>=9.1.1", "faster-whisper>=1.1.0", "celery[redis]>=5.5.3", + "flower>=2.0.1", "redis>=5.2.1", "firecrawl-py>=4.9.0", "boto3>=1.35.0", diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py 
b/surfsense_backend/scripts/verify_chat_image_capability.py index a6be063eb..a49d4eab2 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -55,6 +55,7 @@ from app.services.openrouter_integration_service import ( # noqa: E402 _OPENROUTER_DYNAMIC_MARKER, OpenRouterIntegrationService, ) +from app.services.provider_api_base import resolve_api_base # noqa: E402 from app.services.provider_capabilities import ( # noqa: E402 derive_supports_image_input, is_known_text_only_chat_model, @@ -153,13 +154,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) cap = derive_supports_image_input( - provider=cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) block = is_known_text_only_chat_model( - provider=cfg.get("litellm_provider"), + provider=cfg.get("provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -178,7 +179,11 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: def _build_chat_model_string(cfg: dict) -> str: if cfg.get("custom_provider"): return f"{cfg['custom_provider']}/{cfg['model_name']}" - prefix = cfg.get("litellm_provider") or "openai" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) return f"{prefix}/{cfg['model_name']}" @@ -190,6 +195,11 @@ def _build_chat_model_string(cfg: dict) -> str: async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 
1)[0], + config_api_base=cfg.get("api_base") or None, + ) kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -208,8 +218,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: "max_tokens": 16, "timeout": 60, } - if cfg.get("api_base"): - kwargs["api_base"] = cfg["api_base"] + if api_base: + kwargs["api_base"] = api_base if cfg.get("litellm_params"): # Strip pricing keys — they're tracking-only and confuse some # provider validators (e.g. azure/openai reject unknown kwargs @@ -247,11 +257,20 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = ( async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + if cfg.get("custom_provider"): prefix = cfg["custom_provider"] else: - prefix = cfg.get("litellm_provider") or "openai" + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) base_kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -259,8 +278,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: "size": "1024x1024", "timeout": 120, } - if cfg.get("api_base"): - base_kwargs["api_base"] = cfg["api_base"] + if api_base: + base_kwargs["api_base"] = api_base if cfg.get("api_version"): base_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -330,6 +349,31 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None: report.add(result) +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = 
ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + async def probe_image_gen_configs(report: Report, *, live: bool) -> None: print( "\n[image generation configs from global_image_generation_configs (YAML-static)]" @@ -355,7 +399,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None: async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: - """Sample chat/vision-capable and image-gen models + """Sample one chat (vision-capable), one vision, one image-gen model from the live OpenRouter catalogue. 
Doesn't iterate the full pool (would be hundreds of probes); just validates the integration end- to-end on a representative model from each surface.""" @@ -380,6 +424,9 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: for c in config.GLOBAL_LLM_CONFIGS if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] or_image_gen = [ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" ] @@ -399,6 +446,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ("or-chat", _pick_first(or_chat, "anthropic/claude")), ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] image_picks = [ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` @@ -408,11 +460,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ] print( - f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} " + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" ) - for surface, picked in chat_picks + image_picks: + for surface, picked in chat_picks + vision_picks + image_picks: if not picked: report.add( ProbeResult( @@ -453,6 +505,7 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: async def main(args: argparse.Namespace) -> int: print("Loaded global configs:") print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") print(f" image-gen: 
{len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") @@ -473,6 +526,8 @@ async def main(args: argparse.Namespace) -> int: report = Report() if not args.skip_chat: await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) if not args.skip_image_gen: await probe_image_gen_configs(report, live=args.live) if not args.skip_openrouter: @@ -492,6 +547,7 @@ def _parse_args() -> argparse.Namespace: ) parser.set_defaults(live=True) parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") parser.add_argument("--skip-image-gen", action="store_true") parser.add_argument("--skip-openrouter", action="store_true") return parser.parse_args() diff --git a/surfsense_backend/tests/conftest.py b/surfsense_backend/tests/conftest.py index e227ed287..e2b586aa2 100644 --- a/surfsense_backend/tests/conftest.py +++ b/surfsense_backend/tests/conftest.py @@ -13,14 +13,6 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB) # DATABASE_URL in the environment (e.g. from .env or shell profile). os.environ["DATABASE_URL"] = TEST_DATABASE_URL -# Integration tests authenticate over HTTP via email/password, so the -# password-auth routers must be mounted (they are skipped under AUTH_TYPE=GOOGLE). -# setdefault (not load_dotenv, which runs later with override=False) lets a -# developer's .env=GOOGLE be overridden here while still honouring an explicitly -# exported shell AUTH_TYPE. 
-os.environ.setdefault("AUTH_TYPE", "LOCAL") -os.environ.setdefault("REGISTRATION_ENABLED", "TRUE") - import pytest # noqa: E402 from app.db import DocumentType # noqa: E402 diff --git a/surfsense_backend/tests/e2e/fakes/embeddings.py b/surfsense_backend/tests/e2e/fakes/embeddings.py index 9a01fb84b..ab9e24df9 100644 --- a/surfsense_backend/tests/e2e/fakes/embeddings.py +++ b/surfsense_backend/tests/e2e/fakes/embeddings.py @@ -57,9 +57,9 @@ def install(patches: list[Any]) -> None: # Consumers that did `from app.utils.document_converters import embed_text/texts` ("app.indexing_pipeline.document_embedder.embed_text", fake_embed_text), ("app.indexing_pipeline.document_embedder.embed_texts", fake_embed_texts), - # Index-cache facade binding (the actual call site for indexing.index) + # Pipeline service binding (the actual call site for indexing.index) ( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", fake_embed_texts, ), ] diff --git a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml index 9ea5e1a29..017fa1eb3 100644 --- a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml +++ b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml @@ -19,7 +19,7 @@ # so the resolved auto-pin id is never sent to a real LLM provider. # The values below only need to pass # auto_model_pin_service._is_usable_global_config() -# which requires id / model_name / litellm_provider / api_key all truthy. +# which requires id / model_name / provider / api_key all truthy. 
# # Why TWO entries (premium + free): # auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits @@ -44,10 +44,9 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - litellm_provider: "openai" + provider: "OPENAI" model_name: "fake-e2e-model-premium" api_key: "fake-e2e-api-key-not-for-production" - api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 @@ -61,10 +60,9 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - litellm_provider: "openai" + provider: "OPENAI" model_name: "fake-e2e-model-free" api_key: "fake-e2e-api-key-not-for-production" - api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py index 6b8aa3cdb..19f8e3d0a 100644 --- a/surfsense_backend/tests/integration/conftest.py +++ b/surfsense_backend/tests/integration/conftest.py @@ -123,24 +123,11 @@ async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpac return space -@pytest.fixture(autouse=True) -def _derivation_caches_disabled(monkeypatch): - """Keep integration tests hermetic regardless of the developer's .env. - - With the embedding cache enabled, a successful index of some markdown makes - every later index of the same markdown a cache hit -- silently bypassing - patched ``embed_texts`` fakes/failure injections in unrelated tests. Cache - tests opt back in explicitly via ``monkeypatch.setattr``. 
- """ - monkeypatch.setattr(app_config, "ETL_CACHE_ENABLED", False) - monkeypatch.setattr(app_config, "EMBEDDING_CACHE_ENABLED", False) - - @pytest.fixture def patched_embed_texts(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock, ) return mock @@ -150,7 +137,7 @@ def patched_embed_texts(monkeypatch) -> MagicMock: def patched_embed_texts_raises(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=RuntimeError("Embedding unavailable")) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock, ) return mock @@ -160,11 +147,11 @@ def patched_embed_texts_raises(monkeypatch) -> MagicMock: def patched_chunk_text(monkeypatch) -> MagicMock: mock = MagicMock(return_value=["Test chunk content."]) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", mock, ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock, ) return mock diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index bd889360f..812140be3 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -283,11 +283,11 @@ async def credits(): def _mock_external_apis(monkeypatch): """Mock LLM, embedding, and chunking — these are external API boundaries.""" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", MagicMock(side_effect=lambda texts: 
[[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", MagicMock(return_value=["Test chunk content."]), ) diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py deleted file mode 100644 index 4369cc64d..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Real-infra fixtures for the parse-cache integration tests. - -``cache_local_storage`` points the cache's blob store at a throwaway directory so -tests exercise the real ``LocalFileBackend`` (no cloud, no mocks). ``clean_cache_table`` -removes rows written through the facade's own committing session, which the -savepoint-rolled-back ``db_session`` cannot undo. -""" - -from __future__ import annotations - -import pytest -import pytest_asyncio -from sqlalchemy import text - - -@pytest.fixture -def cache_local_storage(tmp_path, monkeypatch): - from app.config import config - from app.etl_pipeline.cache.storage.backend import resolve_cache_backend - - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) - resolve_cache_backend.cache_clear() - yield tmp_path - resolve_cache_backend.cache_clear() - - -@pytest_asyncio.fixture -async def clean_cache_table(async_engine): - yield - async with async_engine.begin() as conn: - await conn.execute(text("DELETE FROM etl_cache_parses")) diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py deleted file mode 100644 index f9acd02d5..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py +++ /dev/null @@ -1,82 +0,0 @@ -"""extract_with_cache 
end-to-end: real DB + real local storage. - -The only seam mocked is the parser itself (``EtlPipelineService.extract``) -- the -external boundary the facade wraps. Everything else (eligibility, hashing, recall, -remember, blob I/O) runs for real, so these tests prove the actual cost saving: -identical bytes are parsed once and reused. -""" - -from __future__ import annotations - -import pytest - -from app.config import config -from app.etl_pipeline.cache.cached_extraction import extract_with_cache -from app.etl_pipeline.etl_document import EtlRequest, EtlResult, ProcessingMode - -pytestmark = pytest.mark.integration - - -class _CountingParser: - """Stand-in for the external parser; records how often it actually ran.""" - - def __init__(self, **_kwargs) -> None: - pass - - calls = 0 - - async def extract(self, request: EtlRequest) -> EtlResult: - type(self).calls += 1 - return EtlResult( - markdown_content="# Parsed once\n", - etl_service="LLAMACLOUD", - actual_pages=3, - content_type="application/pdf", - ) - - -@pytest.fixture -def counting_parser(monkeypatch): - _CountingParser.calls = 0 - monkeypatch.setattr( - "app.etl_pipeline.cache.cached_extraction.EtlPipelineService", - _CountingParser, - ) - return _CountingParser - - -async def test_identical_uploads_are_parsed_once_then_served_from_cache( - tmp_path, monkeypatch, counting_parser, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") - - pdf = tmp_path / "doc.pdf" - pdf.write_bytes(b"%PDF-1.4 unique-bytes-for-this-test") - request = EtlRequest( - file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC - ) - - first = await extract_with_cache(request) - second = await extract_with_cache(request) - - assert counting_parser.calls == 1 # second upload reused the cache - assert first.markdown_content == second.markdown_content == "# Parsed once\n" - assert second.actual_pages == 3 - 
assert second.content_type == "application/pdf" - - -async def test_disabled_cache_parses_every_time(tmp_path, monkeypatch, counting_parser): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", False) - monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") - - pdf = tmp_path / "doc.pdf" - pdf.write_bytes(b"%PDF-1.4 another-unique-payload") - request = EtlRequest( - file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC - ) - - await extract_with_cache(request) - await extract_with_cache(request) - - assert counting_parser.calls == 2 # bypassed: no reuse diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py deleted file mode 100644 index 4665c44c8..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py +++ /dev/null @@ -1,94 +0,0 @@ -"""CachedParseRepository against real Postgres: the SQL behind eviction & dedup. - -These verify the parts that only a real database can: coldest-first ordering by -reuse then recency, TTL cutoff selection, the size accumulator, and the -insert-once guarantee under a duplicate key. 
-""" - -from __future__ import annotations - -from datetime import UTC, datetime, timedelta - -import pytest - -from app.etl_pipeline.cache.persistence import CachedParseRepository -from app.etl_pipeline.cache.schemas import ParseKey - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -async def _insert(repo, *, sha, size=100, storage_key=None): - key = _key(sha) - await repo.insert( - key=key, - content_type="application/pdf", - actual_pages=1, - storage_backend="local", - storage_key=storage_key or f"etl_cache/{sha}.md", - size_bytes=size, - ) - return key - - -async def test_total_size_bytes_sums_all_rows(db_session): - repo = CachedParseRepository(db_session) - await _insert(repo, sha="a" * 64, size=100) - await _insert(repo, sha="b" * 64, size=250) - - assert await repo.total_size_bytes() == 350 - - -async def test_select_coldest_orders_by_reuse_then_recency(db_session): - repo = CachedParseRepository(db_session) - ka = await _insert(repo, sha="a" * 64) - kb = await _insert(repo, sha="b" * 64) - kc = await _insert(repo, sha="c" * 64) - - # Warm B once and C twice; A stays untouched and should be coldest. - await repo.mark_used((await repo.get(kb)).id) - await repo.mark_used((await repo.get(kc)).id) - await repo.mark_used((await repo.get(kc)).id) - - coldest = await repo.select_coldest(limit=10) - - ids_by_reuse = [c.id for c in coldest] - assert ids_by_reuse[:3] == [ - (await repo.get(ka)).id, - (await repo.get(kb)).id, - (await repo.get(kc)).id, - ] - - -async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): - repo = CachedParseRepository(db_session) - await _insert(repo, sha="a" * 64) - - future = datetime.now(UTC) + timedelta(days=1) - past = datetime.now(UTC) - timedelta(days=1) - - # Row was just used, so it's older than a future cutoff but not a past one. 
- assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 - assert await repo.select_expired(cutoff=past, limit=10) == [] - - -async def test_duplicate_key_insert_keeps_the_first_row(db_session): - repo = CachedParseRepository(db_session) - key = await _insert(repo, sha="a" * 64, size=100, storage_key="etl_cache/first.md") - - # Same content-addressed key (a concurrent re-parse): must be a no-op. - await repo.insert( - key=key, - content_type="application/pdf", - actual_pages=1, - storage_backend="local", - storage_key="etl_cache/second.md", - size_bytes=999, - ) - - row = await repo.get(key) - assert row.storage_key == "etl_cache/first.md" - assert await repo.total_size_bytes() == 100 diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py deleted file mode 100644 index e6041d63e..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py +++ /dev/null @@ -1,65 +0,0 @@ -"""EtlCacheService end-to-end against real Postgres + real local storage. - -Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: -a miss returns nothing, and a remembered parse comes back as an equivalent -``EtlResult`` rebuilt from the row and the blob. 
-""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.etl_document import EtlResult - -pytestmark = pytest.mark.integration - - -def _key(sha: str = "c" * 64) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): - service = EtlCacheService(db_session) - assert await service.recall(_key()) is None - - -async def test_remembered_parse_recalls_as_equivalent_result( - db_session, cache_local_storage -): - service = EtlCacheService(db_session) - stored = EtlResult( - markdown_content="# Cached doc\n\nBody paragraph.\n", - etl_service="LLAMACLOUD", - actual_pages=7, - content_type="application/pdf", - ) - - await service.remember(_key(), stored) - recalled = await service.recall(_key()) - - assert recalled is not None - assert recalled.markdown_content == stored.markdown_content - assert recalled.etl_service == "LLAMACLOUD" - assert recalled.actual_pages == 7 - assert recalled.content_type == "application/pdf" - - -async def test_repeated_recall_keeps_serving_the_same_content( - db_session, cache_local_storage -): - service = EtlCacheService(db_session) - stored = EtlResult( - markdown_content="# Stable\n", - etl_service="LLAMACLOUD", - actual_pages=1, - content_type="application/pdf", - ) - await service.remember(_key(), stored) - - first = await service.recall(_key()) - second = await service.recall(_key()) - - assert first is not None and second is not None - assert first.markdown_content == second.markdown_content == "# Stable\n" diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py deleted file mode 100644 index 939ac74a5..000000000 --- 
a/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py +++ /dev/null @@ -1,96 +0,0 @@ -"""The eviction task on real infra: TTL expiry first, then coldest-over-budget. - -Seeds entries through the real cache (DB rows + local blobs), runs the actual -``_evict`` coroutine, and checks what survives via ``recall`` -- no mocks. TTL and -budget are driven through config so each phase can be exercised in isolation. -""" - -from __future__ import annotations - -import pytest - -from app.config import config -from app.etl_pipeline.cache.eviction.task import _evict -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.service import EtlCacheService -from app.etl_pipeline.etl_document import EtlResult -from app.tasks.celery_tasks import get_celery_session_maker - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> ParseKey: - return ParseKey.for_document(sha, etl_service="LLAMACLOUD", mode="basic", version=1) - - -def _result(markdown: str) -> EtlResult: - return EtlResult( - markdown_content=markdown, - etl_service="LLAMACLOUD", - actual_pages=1, - content_type="application/pdf", - ) - - -async def _remember(key: ParseKey, result: EtlResult) -> None: - async with get_celery_session_maker()() as session: - await EtlCacheService(session).remember(key, result) - - -async def _recall(key: ParseKey) -> EtlResult | None: - async with get_celery_session_maker()() as session: - return await EtlCacheService(session).recall(key) - - -async def test_expired_entries_are_pruned( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr( - config, "ETL_CACHE_TTL_DAYS", -1 - ) # cutoff in the future -> stale - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) # size phase no-op - - key = _key("a" * 64) - await _remember(key, _result("# stale doc\n")) - - await _evict() - - assert await _recall(key) is None - - -async def 
test_coldest_entries_are_shed_when_over_budget( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) # nothing TTL-expired - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 1) # ~1 MiB budget - - cold = _key("a" * 64) - warm = _key("b" * 64) - # Two ~0.6 MiB entries together exceed the 1 MiB budget; one must go. - await _remember(cold, _result("x" * 600_000)) - await _remember(warm, _result("y" * 600_000)) - - # A reuse makes `warm` warmer than `cold`, so `cold` is the eviction target. - assert await _recall(warm) is not None - - await _evict() - - assert await _recall(cold) is None - assert await _recall(warm) is not None - - -async def test_nothing_is_evicted_within_ttl_and_budget( - monkeypatch, cache_local_storage, clean_cache_table -): - monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) - monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) - monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) - - key = _key("a" * 64) - await _remember(key, _result("# keep me\n")) - - await _evict() - - assert await _recall(key) is not None diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py deleted file mode 100644 index a9d685017..000000000 --- a/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py +++ /dev/null @@ -1,42 +0,0 @@ -"""MarkdownCacheStore against a real local filesystem backend (no mocks). - -Proves the blob side of the cache: markdown written under a content-addressed key -comes back byte-for-byte, and a delete actually removes it. 
-""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage import MarkdownCacheStore -from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key - -pytestmark = pytest.mark.integration - - -def _key() -> ParseKey: - return ParseKey.for_document( - "d" * 64, etl_service="LLAMACLOUD", mode="basic", version=1 - ) - - -async def test_save_then_load_round_trips_markdown(cache_local_storage): - store = MarkdownCacheStore() - markdown = "# Title\n\nBody with unicode: café, naïve, 漢字.\n" - - storage_key = await store.save(_key(), markdown) - - assert storage_key == build_parse_object_key(_key()) - assert await store.load(storage_key) == markdown - - -async def test_delete_removes_the_blob(cache_local_storage): - store = MarkdownCacheStore() - storage_key = await store.save(_key(), "to be deleted") - - await store.delete(storage_key) - - # Eviction deleted the blob; a later read must fail rather than serve stale. 
- with pytest.raises(FileNotFoundError): - await store.load(storage_key) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py index 814129c8d..311716052 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py @@ -177,7 +177,7 @@ async def test_reindex_sets_status_ready(db_session, db_search_space, db_user, m async def test_reindex_replaces_chunks(db_session, db_search_space, db_user, mocker): """Reindexing replaces old chunks with new content rather than appending.""" mocker.patch( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", side_effect=[["Original chunk."], ["Updated chunk."]], ) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py deleted file mode 100644 index 6acb457ee..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Real-infra fixtures for the embedding-cache integration tests. - -``cache_local_storage`` points the shared cache backend at a throwaway directory -so tests exercise the real ``LocalFileBackend`` (no cloud, no mocks); the -embedding cache reuses the ETL cache backend, hence the ``ETL_CACHE_STORAGE_*`` -knobs. ``clean_embedding_cache_table`` removes rows written through the store's -own committing session, which the savepoint-rolled-back ``db_session`` cannot undo. 
-""" - -from __future__ import annotations - -import pytest -import pytest_asyncio -from sqlalchemy import text - - -@pytest.fixture -def cache_local_storage(tmp_path, monkeypatch): - from app.config import config - from app.etl_pipeline.cache.storage.backend import resolve_cache_backend - - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") - monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) - resolve_cache_backend.cache_clear() - yield tmp_path - resolve_cache_backend.cache_clear() - - -@pytest_asyncio.fixture -async def clean_embedding_cache_table(async_engine): - yield - async with async_engine.begin() as conn: - await conn.execute(text("DELETE FROM embedding_cache_sets")) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py deleted file mode 100644 index 446932793..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py +++ /dev/null @@ -1,110 +0,0 @@ -"""CachedEmbeddingSetRepository against real Postgres: the SQL behind eviction & dedup. - -These verify the parts only a real database can: the size accumulator, -coldest-first ordering by reuse then recency, TTL cutoff selection, the -insert-once guarantee under a duplicate key, and the reuse counter. 
-""" - -from __future__ import annotations - -from datetime import UTC, datetime, timedelta - -import pytest - -from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.schemas import EmbeddingKey - -pytestmark = pytest.mark.integration - - -def _key(sha: str) -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256=sha, - embedding_model="test-model", - embedding_dim=4, - chunker_kind="hybrid", - chunker_version=1, - ) - - -async def _insert(repo, *, sha, size=100, storage_key=None, chunk_count=1): - key = _key(sha) - await repo.insert( - key=key, - storage_backend="local", - storage_key=storage_key or f"embedding_cache/{sha}.emb", - size_bytes=size, - chunk_count=chunk_count, - ) - return key - - -async def test_total_size_bytes_sums_all_rows(db_session): - repo = CachedEmbeddingSetRepository(db_session) - await _insert(repo, sha="a" * 64, size=100) - await _insert(repo, sha="b" * 64, size=250) - - assert await repo.total_size_bytes() == 350 - - -async def test_select_coldest_orders_by_reuse_then_recency(db_session): - repo = CachedEmbeddingSetRepository(db_session) - ka = await _insert(repo, sha="a" * 64) - kb = await _insert(repo, sha="b" * 64) - kc = await _insert(repo, sha="c" * 64) - - # Warm B once and C twice; A stays untouched and should be coldest. 
- await repo.mark_used((await repo.get(kb)).id) - await repo.mark_used((await repo.get(kc)).id) - await repo.mark_used((await repo.get(kc)).id) - - coldest = await repo.select_coldest(limit=10) - - assert [c.id for c in coldest][:3] == [ - (await repo.get(ka)).id, - (await repo.get(kb)).id, - (await repo.get(kc)).id, - ] - - -async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): - repo = CachedEmbeddingSetRepository(db_session) - await _insert(repo, sha="a" * 64) - - future = datetime.now(UTC) + timedelta(days=1) - past = datetime.now(UTC) - timedelta(days=1) - - # Row was just used, so it predates a future cutoff but not a past one. - assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 - assert await repo.select_expired(cutoff=past, limit=10) == [] - - -async def test_duplicate_key_insert_keeps_the_first_row(db_session): - repo = CachedEmbeddingSetRepository(db_session) - key = await _insert( - repo, sha="a" * 64, size=100, storage_key="embedding_cache/first.emb" - ) - - # Same content-addressed key (a concurrent re-embed): must be a no-op. 
- await repo.insert( - key=key, - storage_backend="local", - storage_key="embedding_cache/second.emb", - size_bytes=999, - chunk_count=42, - ) - - row = await repo.get(key) - assert row.storage_key == "embedding_cache/first.emb" - assert await repo.total_size_bytes() == 100 - - -async def test_mark_used_increments_reuse_count(db_session): - repo = CachedEmbeddingSetRepository(db_session) - key = await _insert(repo, sha="a" * 64) - assert (await repo.get(key)).times_reused == 0 - - await repo.mark_used((await repo.get(key)).id) - await repo.mark_used((await repo.get(key)).id) - - assert (await repo.get(key)).times_reused == 2 diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py deleted file mode 100644 index 548208131..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py +++ /dev/null @@ -1,74 +0,0 @@ -"""EmbeddingCacheService end-to-end against real Postgres + real local storage. - -Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: -a miss returns nothing, a remembered set comes back as equivalent vectors, and a -dimension mismatch is refused rather than served. 
-""" - -from __future__ import annotations - -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.service import EmbeddingCacheService - -pytestmark = pytest.mark.integration - - -def _key(sha: str = "c" * 64, *, dim: int = 4) -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256=sha, - embedding_model="test-model", - embedding_dim=dim, - chunker_kind="hybrid", - chunker_version=1, - ) - - -async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): - service = EmbeddingCacheService(db_session) - assert await service.recall(_key()) is None - - -async def test_remembered_set_recalls_as_equivalent_vectors( - db_session, cache_local_storage, clean_embedding_cache_table -): - service = EmbeddingCacheService(db_session) - stored = EmbeddingSet( - summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), - chunks=[ - CachedChunk( - "first chunk", np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) - ), - CachedChunk( - "second chunk", np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32) - ), - ], - ) - - await service.remember(_key(), stored) - recalled = await service.recall(_key()) - - assert recalled is not None - assert np.array_equal(recalled.summary_embedding, stored.summary_embedding) - assert [c.text for c in recalled.chunks] == ["first chunk", "second chunk"] - assert np.array_equal(recalled.chunks[0].embedding, stored.chunks[0].embedding) - assert np.array_equal(recalled.chunks[1].embedding, stored.chunks[1].embedding) - - -async def test_recall_refuses_a_set_whose_dimension_changed( - db_session, cache_local_storage, clean_embedding_cache_table -): - # A model kept its name but changed its output width: never serve the stale blob. 
- service = EmbeddingCacheService(db_session) - stored = EmbeddingSet( - summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), - chunks=[CachedChunk("c", np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))], - ) - await service.remember(_key(dim=4), stored) - - # Same identity (model + chunker + markdown), but the caller now expects dim 8. - recalled = await service.recall(_key(dim=8)) - - assert recalled is None diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py deleted file mode 100644 index 83becd7b5..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py +++ /dev/null @@ -1,63 +0,0 @@ -"""EmbeddingCacheStore against a real local filesystem backend (no mocks). - -Proves the blob side of the cache: an embedding set written under a -content-addressed key comes back with identical vectors, and a delete actually -removes it. 
-""" - -from __future__ import annotations - -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.storage import EmbeddingCacheStore -from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key - -pytestmark = pytest.mark.integration - - -def _key() -> EmbeddingKey: - return EmbeddingKey( - markdown_sha256="d" * 64, - embedding_model="test-model", - embedding_dim=4, - chunker_kind="hybrid", - chunker_version=1, - ) - - -def _set() -> EmbeddingSet: - return EmbeddingSet( - summary_embedding=np.array([0.5, 0.25, 0.125, 0.0625], dtype=np.float32), - chunks=[ - CachedChunk("café, naïve, 漢字", np.array([1, 2, 3, 4], dtype=np.float32)), - CachedChunk("second", np.array([5, 6, 7, 8], dtype=np.float32)), - ], - ) - - -async def test_save_then_load_round_trips_the_embedding_set(cache_local_storage): - store = EmbeddingCacheStore() - embedding_set = _set() - - storage_key, size_bytes = await store.save(_key(), embedding_set) - loaded = await store.load(storage_key) - - assert storage_key == build_embedding_object_key(_key()) - assert size_bytes > 0 - assert np.array_equal(loaded.summary_embedding, embedding_set.summary_embedding) - assert [c.text for c in loaded.chunks] == ["café, naïve, 漢字", "second"] - assert np.array_equal(loaded.chunks[0].embedding, embedding_set.chunks[0].embedding) - assert np.array_equal(loaded.chunks[1].embedding, embedding_set.chunks[1].embedding) - - -async def test_delete_removes_the_blob(cache_local_storage): - store = EmbeddingCacheStore() - storage_key, _ = await store.save(_key(), _set()) - - await store.delete(storage_key) - - # Eviction deleted the blob; a later read must fail rather than serve stale. 
- with pytest.raises(FileNotFoundError): - await store.load(storage_key) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py deleted file mode 100644 index 68d5ec0af..000000000 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Edit path: re-indexing a document diffs chunks instead of replacing them. - -Unchanged paragraphs must keep their chunk rows (ids survive -> embeddings and -HNSW entries untouched), only new text is embedded, removed text is deleted, -and (position) keeps presentation order correct throughout. -""" - -import pytest -from sqlalchemy import select - -from app.db import Chunk, DocumentStatus -from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService - -pytestmark = pytest.mark.integration - -_V1 = "Intro paragraph.\n\nBody paragraph.\n\nOutro paragraph." - - -@pytest.fixture -def paragraph_chunker(monkeypatch): - """One chunk per markdown paragraph, so edits map to chunk-level diffs.""" - - def _split(markdown, **_kwargs): - return [p for p in markdown.split("\n\n") if p.strip()] - - monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", _split - ) - monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", _split - ) - - -async def _index(service, connector_doc): - prepared = await service.prepare_for_indexing([connector_doc]) - document = prepared[0] - await service.index(document, connector_doc) - return document - - -async def _load_chunks(db_session, document_id): - result = await db_session.execute( - select(Chunk) - .where(Chunk.document_id == document_id) - .order_by(Chunk.position, Chunk.id) - ) - return result.scalars().all() - - -@pytest.mark.usefixtures("paragraph_chunker") -async def test_edit_keeps_unchanged_rows_and_embeds_only_the_new_text( - db_session, - 
db_search_space, - make_connector_document, - patched_embed_texts, -): - service = IndexingPipelineService(session=db_session) - doc_v1 = make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ) - document = await _index(service, doc_v1) - - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - patched_embed_texts.reset_mock() - - edited = "Intro paragraph.\n\nBody paragraph EDITED.\n\nOutro paragraph." - doc_v2 = make_connector_document( - search_space_id=db_search_space.id, source_markdown=edited - ) - await _index(service, doc_v2) - - chunks = await _load_chunks(db_session, document.id) - by_content = {c.content: c for c in chunks} - - # Untouched paragraphs keep their rows (same ids => embeddings reused, - # no HNSW/GIN churn); the edited paragraph got a fresh row. - assert by_content["Intro paragraph."].id == ids_v1["Intro paragraph."] - assert by_content["Outro paragraph."].id == ids_v1["Outro paragraph."] - assert "Body paragraph." not in by_content - assert by_content["Body paragraph EDITED."].id not in ids_v1.values() - - # Exactly one embed call: the document summary plus only the edited text. 
- (embedded_texts,) = patched_embed_texts.call_args.args - assert embedded_texts == [edited, "Body paragraph EDITED."] - - assert [c.position for c in chunks] == [0, 1, 2] - assert [c.content for c in chunks] == [ - "Intro paragraph.", - "Body paragraph EDITED.", - "Outro paragraph.", - ] - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_head_insert_shifts_positions_without_new_rows_for_old_text( - db_session, - db_search_space, - make_connector_document, -): - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown="Brand new opener.\n\n" + _V1, - ), - ) - - chunks = await _load_chunks(db_session, document.id) - assert [c.content for c in chunks] == [ - "Brand new opener.", - "Intro paragraph.", - "Body paragraph.", - "Outro paragraph.", - ] - assert [c.position for c in chunks] == [0, 1, 2, 3] - # The three original rows survived the shift. 
- surviving = {c.content: c.id for c in chunks if c.content in ids_v1} - assert surviving == ids_v1 - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_removed_paragraph_is_deleted_and_order_compacts( - db_session, - db_search_space, - make_connector_document, -): - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} - - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown="Intro paragraph.\n\nOutro paragraph.", - ), - ) - - chunks = await _load_chunks(db_session, document.id) - assert [(c.content, c.position) for c in chunks] == [ - ("Intro paragraph.", 0), - ("Outro paragraph.", 1), - ] - assert chunks[0].id == ids_v1["Intro paragraph."] - assert chunks[1].id == ids_v1["Outro paragraph."] - - -@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") -async def test_kill_switch_falls_back_to_full_replace( - db_session, - db_search_space, - make_connector_document, - monkeypatch, -): - from app.config import config - - service = IndexingPipelineService(session=db_session) - document = await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, source_markdown=_V1 - ), - ) - ids_v1 = {c.id for c in await _load_chunks(db_session, document.id)} - - monkeypatch.setattr(config, "CHUNK_RECONCILE_ENABLED", False) - await _index( - service, - make_connector_document( - search_space_id=db_search_space.id, - source_markdown=_V1 + "\n\nAppended paragraph.", - ), - ) - - chunks = await _load_chunks(db_session, document.id) - # Legacy behavior: every row is recreated, even unchanged paragraphs. 
- assert {c.id for c in chunks}.isdisjoint(ids_v1) - assert [c.position for c in chunks] == [0, 1, 2, 3] - assert DocumentStatus.is_state(document.status, DocumentStatus.READY) diff --git a/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py index 5ca560f11..f602f2e66 100644 --- a/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py +++ b/surfsense_backend/tests/integration/notifications/test_document_processing_handler.py @@ -78,23 +78,3 @@ async def test_processing_completed_failure( assert done.title == "Failed: report.pdf" assert done.message == "Processing failed: bad file" assert done.notification_metadata["status"] == "failed" - - -async def test_processing_started_truncates_long_filename( - db_session: AsyncSession, db_user: User, db_search_space: SearchSpace -): - """A long filename is truncated in the title but kept in metadata.""" - long_name = "a" * 250 - - notification = await handler.notify_processing_started( - session=db_session, - user_id=db_user.id, - document_type="FILE", - document_name=long_name, - search_space_id=db_search_space.id, - ) - - assert len(notification.title) <= 200 - assert notification.title.startswith("Processing: ") - assert notification.title.endswith("...") - assert notification.notification_metadata["document_name"] == long_name diff --git a/surfsense_backend/tests/integration/podcasts/conftest.py b/surfsense_backend/tests/integration/podcasts/conftest.py index 75248a6a1..f244c17d2 100644 --- a/surfsense_backend/tests/integration/podcasts/conftest.py +++ b/surfsense_backend/tests/integration/podcasts/conftest.py @@ -120,9 +120,6 @@ class FakeStorageBackend: async def open_stream(self, key: str) -> AsyncIterator[bytes]: yield self.objects.get(key, b"audio-bytes") - async def exists(self, key: str) -> bool: - return key in self.objects - async def delete(self, key: str) -> 
None: self.deleted.append(key) @@ -217,7 +214,7 @@ def build_spec( slot=1, name="Guest", role=SpeakerRole.GUEST, voice_id=voice_ids[1] ), ], - duration=DurationTarget(min_seconds=600, max_seconds=1200), + duration=DurationTarget(min_minutes=10, max_minutes=20), ) diff --git a/surfsense_backend/tests/integration/podcasts/test_draft_task.py b/surfsense_backend/tests/integration/podcasts/test_draft_task.py index 014d98b1f..7dadfc2f5 100644 --- a/surfsense_backend/tests/integration/podcasts/test_draft_task.py +++ b/surfsense_backend/tests/integration/podcasts/test_draft_task.py @@ -76,7 +76,8 @@ async def test_quota_denial_fails_the_podcast_without_a_transcript( async def _deny(**_kwargs): raise QuotaInsufficientError( usage_type="podcast_generation", - balance_micros=5_000_000, + used_micros=5_000_000, + limit_micros=5_000_000, remaining_micros=0, ) yield # pragma: no cover - unreachable, satisfies the CM protocol diff --git a/surfsense_backend/tests/integration/podcasts/test_public_stream.py b/surfsense_backend/tests/integration/podcasts/test_public_stream.py index 63f634234..d2ba1d1b9 100644 --- a/surfsense_backend/tests/integration/podcasts/test_public_stream.py +++ b/surfsense_backend/tests/integration/podcasts/test_public_stream.py @@ -48,22 +48,6 @@ async def test_public_stream_serves_audio_via_storage_key( assert resp.content == b"public-audio" -async def test_public_stream_404_when_object_missing( - client, db_session, db_search_space, db_user, fake_storage -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-gone", - podcasts=[{"original_id": 556, "storage_key": "podcasts/gone.mp3"}], - ) - - resp = await client.get("/api/v1/public/tok-gone/podcasts/556/stream") - - assert resp.status_code == 404 - - async def test_public_stream_404_when_podcast_absent_from_snapshot( client, db_session, db_search_space, db_user ): diff --git a/surfsense_backend/tests/integration/podcasts/test_streaming.py 
b/surfsense_backend/tests/integration/podcasts/test_streaming.py index b924e2971..82456bac9 100644 --- a/surfsense_backend/tests/integration/podcasts/test_streaming.py +++ b/surfsense_backend/tests/integration/podcasts/test_streaming.py @@ -1,7 +1,8 @@ """Streaming a podcast's rendered audio over HTTP. -A ready podcast streams its bytes; an in-flight one is 409, a stored-but-missing -object is 404. Storage is an in-memory backend (the object store is a boundary). +A ready podcast streams its bytes from the storage backend; a podcast with no +stored audio returns 404. Storage is an in-memory backend (the object store is a +system boundary). """ from __future__ import annotations @@ -30,23 +31,11 @@ async def test_stream_serves_stored_audio( assert resp.content == b"the-audio" -async def test_stream_409_while_in_flight(client, db_search_space, make_podcast): +async def test_stream_404_when_no_audio(client, db_search_space, make_podcast): podcast = await make_podcast( search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING ) resp = await client.get(f"{BASE}/{podcast.id}/stream") - assert resp.status_code == 409 - - -async def test_stream_404_when_object_missing( - client, db_search_space, make_podcast, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - assert resp.status_code == 404 diff --git a/surfsense_backend/tests/integration/podcasts/test_voices.py b/surfsense_backend/tests/integration/podcasts/test_voices.py index fd41bfd4e..688ddad56 100644 --- a/surfsense_backend/tests/integration/podcasts/test_voices.py +++ b/surfsense_backend/tests/integration/podcasts/test_voices.py @@ -29,23 +29,3 @@ async def test_voices_503_when_no_tts_configured(client, monkeypatch): resp = await client.get(f"{BASE}/voices") assert resp.status_code == 503 - - -async def test_languages_returns_the_active_providers_offering(client): - """The brief form 
renders exactly what the backend offers — for a wildcard - provider (openai/tts-1) that is the curated list plus free entry.""" - resp = await client.get(f"{BASE}/languages") - - assert resp.status_code == 200 - offering = resp.json() - assert "en" in offering["languages"] - assert "fr" in offering["languages"] - assert offering["allows_custom"] is True - - -async def test_languages_503_when_no_tts_configured(client, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", "") - - resp = await client.get(f"{BASE}/languages") - - assert resp.status_code == 503 diff --git a/surfsense_backend/tests/integration/test_connector_index_authz.py b/surfsense_backend/tests/integration/test_connector_index_authz.py deleted file mode 100644 index cea2407cc..000000000 --- a/surfsense_backend/tests/integration/test_connector_index_authz.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Cross-search-space authorization on the connector index endpoint. - -``POST /search-source-connectors/{connector_id}/index?search_space_id=`` must -authorize against the **connector's own** ``search_space_id`` (matching the -read/update/delete handlers), not the caller-supplied ``search_space_id`` query -parameter, and must reject a connector that does not belong to the requested -search space. - -Without this, a user who owns search space B could index another user's -connector (which lives in space A) by passing ``search_space_id=B``: the -background indexer would run with the **victim connector's stored credentials** -and write the fetched content into the attacker's space. These tests pin that -boundary. 
-""" - -from __future__ import annotations - -import contextlib -import uuid -from unittest.mock import AsyncMock, patch - -import pytest -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db import ( - SearchSourceConnector, - SearchSourceConnectorType, - SearchSpace, - User, -) -from app.routes.search_source_connectors_routes import index_connector_content -from app.routes.search_spaces_routes import create_default_roles_and_membership - -pytestmark = pytest.mark.integration - -# The handler imports ``check_permission`` into its own module namespace. -_CHECK_PERMISSION = "app.routes.search_source_connectors_routes.check_permission" - - -async def _make_user_with_space(session: AsyncSession) -> tuple[User, SearchSpace]: - """A user plus a search space they own, with the default roles/membership - the ``POST /searchspaces`` route would create (so ``check_permission`` would - legitimately pass for this user on this space).""" - user = User( - id=uuid.uuid4(), - email=f"authz-{uuid.uuid4()}@surfsense.test", - hashed_password="x", - is_active=True, - is_superuser=False, - is_verified=True, - ) - session.add(user) - await session.flush() - space = SearchSpace(name=f"Space {uuid.uuid4().hex[:8]}", user_id=user.id) - session.add(space) - await session.flush() - await create_default_roles_and_membership(session, space.id, user.id) - await session.flush() - return user, space - - -async def _make_connector( - session: AsyncSession, - owner: User, - space: SearchSpace, - connector_type: SearchSourceConnectorType, -) -> SearchSourceConnector: - connector = SearchSourceConnector( - name="Connector", - connector_type=connector_type, - # A stored credential the indexer would use — the thing a cross-tenant - # index must never be able to abuse. 
- config={ - "GITHUB_PAT": "victim-secret-pat", - "repo_full_names": ["octocat/Hello-World"], - }, - is_indexable=True, - search_space_id=space.id, - user_id=owner.id, - ) - session.add(connector) - await session.flush() - return connector - - -class TestConnectorIndexCrossSpaceAuthz: - async def test_cross_space_index_is_rejected_before_permission_check( - self, db_session: AsyncSession - ): - """Attacker (owns space B) cannot index victim's connector (in space A) - by passing ``search_space_id=B``. - - The mismatch is rejected with 404 **before** ``check_permission`` runs — - which is essential, because that permission check *would* pass: the - attacker legitimately holds ``CONNECTORS_UPDATE`` on their own space B. - """ - victim, space_a = await _make_user_with_space(db_session) - attacker, space_b = await _make_user_with_space(db_session) - connector_a = await _make_connector( - db_session, victim, space_a, SearchSourceConnectorType.GITHUB_CONNECTOR - ) - - with ( - patch(_CHECK_PERMISSION, new=AsyncMock()) as check_permission_mock, - pytest.raises(HTTPException) as exc_info, - ): - await index_connector_content( - connector_id=connector_a.id, - search_space_id=space_b.id, # the attacker's own space - session=db_session, - user=attacker, - ) - - assert exc_info.value.status_code == 404 - # Rejected at the search-space reconciliation, never reaching (or relying - # on) the permission check — which would have passed for space B. 
- check_permission_mock.assert_not_awaited() - - async def test_same_space_index_authorizes_against_the_connectors_own_space( - self, db_session: AsyncSession - ): - """A legitimate same-space index passes the reconciliation and authorizes - ``check_permission`` against the connector's **own** search space (not the - client-supplied query param).""" - owner, space = await _make_user_with_space(db_session) - # A "live" connector type returns early (no Celery dispatch) right after - # the permission check, so the call exercises the authz path cleanly. - connector = await _make_connector( - db_session, owner, space, SearchSourceConnectorType.CLICKUP_CONNECTOR - ) - - # Any downstream indexing behaviour is irrelevant to the authz contract - # under test; we only assert what space was authorized. - with ( - patch(_CHECK_PERMISSION, new=AsyncMock()) as check_permission_mock, - contextlib.suppress(Exception), - ): - await index_connector_content( - connector_id=connector.id, - search_space_id=space.id, # the connector's own space - session=db_session, - user=owner, - ) - - check_permission_mock.assert_awaited_once() - # The space passed to check_permission must be the connector's own space. 
- assert connector.search_space_id in check_permission_mock.await_args.args diff --git a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py deleted file mode 100644 index e987f8441..000000000 --- a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Regression tests for model-boundary message sanitization.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage - -from app.agents.chat.runtime.llm_config import _sanitize_messages - -pytestmark = pytest.mark.unit - - -def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None: - original = AIMessage( - content=[ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "visible answer"}, - ] - ) - - sanitized = _sanitize_messages([original]) - - assert sanitized[0].content == "visible answer" - assert original.content == [ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "visible answer"}, - ] - - -def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None: - message = AIMessage( - content="", - tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}], - ) - - sanitized = _sanitize_messages([message]) - - assert sanitized[0].content is None - assert message.content == "" diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py index f5709e517..79da12933 100644 --- a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -1,6 +1,6 @@ """Lock the runtime model-policy backstop in ``build_dependencies``. 
-Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so +Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so runs are insulated from later chat/search-space model changes), and the model policy is re-checked at run time so a captured model that is no longer billable fails the run clearly. When no snapshot is present, resolution falls back to the @@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch): return None -async def test_build_dependencies_resolves_captured_chat_model_id( +async def test_build_dependencies_resolves_captured_agent_llm_id( monkeypatch: pytest.MonkeyPatch, patched_side_effects ) -> None: - """The bundle loads with the *captured* ``chat_model_id``, not the live search space.""" + """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" captured: dict[str, Any] = {} async def _fake_load(_session, *, config_id, search_space_id): @@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_chat_model_id( lambda _ss: pytest.fail("search-space policy should not run on captured path"), ) - search_space = SimpleNamespace(chat_model_id=-99) + search_space = SimpleNamespace(agent_llm_id=-99) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) assert captured == {"config_id": -7, "search_space_id": 42} @@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids( monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) await build_dependencies( - session=_FakeSession(SimpleNamespace(chat_model_id=0)), + session=_FakeSession(SimpleNamespace(agent_llm_id=0)), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) assert seen == { 
- "chat_model_id": -7, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -7, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } @@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation( def _raise(**_kw): raise AutomationModelPolicyError( - [{"kind": "image", "model_id": -2, "reason": "free model"}] + [{"kind": "image", "config_id": -2, "reason": "free model"}] ) monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) @@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation( with pytest.raises(DependencyError): await build_dependencies( - session=_FakeSession(SimpleNamespace(chat_model_id=-7)), + session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), search_space_id=42, - chat_model_id=-7, - image_gen_model_id=-2, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=-2, + vision_llm_config_id=-1, ) @@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space( lambda **_kw: pytest.fail("captured policy should not run on fallback path"), ) - search_space = SimpleNamespace(chat_model_id=-7) + search_space = SimpleNamespace(agent_llm_id=-7) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42 ) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py index c89624fbf..d7e3c4a0c 100644 --- a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -28,9 +28,9 @@ def _run() -> SimpleNamespace: def test_build_action_ctx_propagates_captured_models() -> None: """``definition.models`` flows onto the ActionContext model fields.""" models = AutomationModels( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + 
vision_llm_config_id=-1, ) ctx = _build_action_ctx( cast(AsyncSession, None), @@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None: ) assert ctx.search_space_id == 42 - assert ctx.chat_model_id == -1 - assert ctx.image_gen_model_id == 5 - assert ctx.vision_model_id == -1 + assert ctx.agent_llm_id == -1 + assert ctx.image_generation_config_id == 5 + assert ctx.vision_llm_config_id == -1 def test_build_action_ctx_none_models_leaves_fields_none() -> None: @@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None: None, ) - assert ctx.chat_model_id is None - assert ctx.image_gen_model_id is None - assert ctx.vision_model_id is None + assert ctx.agent_llm_id is None + assert ctx.image_generation_config_id is None + assert ctx.vision_llm_config_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index dc7221b11..25e193ffa 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None: name="Daily digest", plan=[PlanStep(step_id="s1", action="agent_task")], models=AutomationModels( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, ), ) dumped = definition.model_dump(mode="json", by_alias=True) assert dumped["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } restored = AutomationDefinition.model_validate(dumped) assert restored.models is not None - assert restored.models.chat_model_id == -1 - assert restored.models.image_gen_model_id == 5 - assert restored.models.vision_model_id == -1 + 
assert restored.models.agent_llm_id == -1 + assert restored.models.image_generation_config_id == 5 + assert restored.models.vision_llm_config_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index c97dec6a2..0bbff39dc 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation( def _raise(_ss): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": 0, "reason": "Auto mode"}] + [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] ) monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) - service = _service(SimpleNamespace(chat_model_id=0)) + service = _service(SimpleNamespace(agent_llm_id=0)) with pytest.raises(HTTPException) as exc_info: await service._assert_models_billable(1) @@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok( automation_mod, "assert_automation_models_billable", lambda _ss: None ) - search_space = SimpleNamespace(chat_model_id=-1) + search_space = SimpleNamespace(agent_llm_id=-1) service = _service(search_space) assert await service._assert_models_billable(1) is search_space @@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - chat_model_id=-1, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-1, + image_generation_config_id=5, + vision_llm_config_id=-1, ) service = _service(search_space) payload = AutomationCreate( @@ -137,9 +137,9 @@ async def 
test_create_injects_captured_models_from_search_space( automation = await service.create(payload) assert automation.definition["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } @@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - chat_model_id=None, - image_gen_model_id=None, - vision_model_id=None, + agent_llm_id=None, + image_generation_config_id=None, + vision_llm_config_id=None, ) service = _service(search_space) payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) @@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( automation = await service.create(payload) assert automation.definition["models"] == { - "chat_model_id": 0, - "image_gen_model_id": 0, - "vision_model_id": 0, + "agent_llm_id": 0, + "image_generation_config_id": 0, + "vision_llm_config_id": 0, } @@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided( ) validated: dict[str, Any] = {} - def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): validated["ids"] = ( - chat_model_id, - image_gen_model_id, - vision_model_id, + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) - service = _service(SimpleNamespace(chat_model_id=-99)) + service = _service(SimpleNamespace(agent_llm_id=-99)) payload = AutomationCreate( search_space_id=1, name="A", 
definition=_definition( models=AutomationModels( - chat_model_id=-1, - image_gen_model_id=7, - vision_model_id=-2, + agent_llm_id=-1, + image_generation_config_id=7, + vision_llm_config_id=-2, ) ), ) @@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided( assert validated["ids"] == (-1, 7, -2) assert automation.definition["models"] == { - "chat_model_id": -1, - "image_gen_model_id": 7, - "vision_model_id": -2, + "agent_llm_id": -1, + "image_generation_config_id": 7, + "vision_llm_config_id": -2, } @@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models( ) -> None: """A non-billable explicit selection maps the policy error to HTTP 422.""" - def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": -3, "reason": "free model"}] + [{"kind": "llm", "config_id": -3, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) - service = _service(SimpleNamespace(chat_model_id=-3)) + service = _service(SimpleNamespace(agent_llm_id=-3)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - chat_model_id=-3, - image_gen_model_id=7, - vision_model_id=-2, + agent_llm_id=-3, + image_generation_config_id=7, + vision_llm_config_id=-2, ) ), ) @@ -277,9 +277,9 @@ async def test_update_preserves_captured_models( ) -> None: """A definition edit carries over the previously captured ``models``.""" captured = { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -318,20 +318,20 @@ async def 
test_update_honors_changed_models_when_valid( "name": "A", "plan": [], "models": { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, }, }, version=3, ) validated: dict[str, Any] = {} - def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): + def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): validated["ids"] = ( - chat_model_id, - image_gen_model_id, - vision_model_id, + agent_llm_id, + image_generation_config_id, + vision_llm_config_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - chat_model_id=-2, - image_gen_model_id=9, - vision_model_id=-2, + agent_llm_id=-2, + image_generation_config_id=9, + vision_llm_config_id=-2, ) ) ) @@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid( assert validated["ids"] == (-2, 9, -2) assert result.definition["models"] == { - "chat_model_id": -2, - "image_gen_model_id": 9, - "vision_model_id": -2, + "agent_llm_id": -2, + "image_generation_config_id": 9, + "vision_llm_config_id": -2, } assert result.version == 4 @@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models( "name": "A", "plan": [], "models": { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, }, }, version=3, ) - def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): + def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): raise AutomationModelPolicyError( - [{"kind": "llm", "model_id": -7, "reason": "free model"}] + [{"kind": "llm", "config_id": -7, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) 
@@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - chat_model_id=-7, - image_gen_model_id=5, - vision_model_id=-1, + agent_llm_id=-7, + image_generation_config_id=5, + vision_llm_config_id=-1, ) ) ) @@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation( premium without an unrelated edit tripping the policy check. """ captured = { - "chat_model_id": -1, - "image_gen_model_id": 5, - "vision_model_id": -1, + "agent_llm_id": -1, + "image_generation_config_id": 5, + "vision_llm_config_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload( lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, ) - service = _service(SimpleNamespace(chat_model_id=-2)) + service = _service(SimpleNamespace(agent_llm_id=-2)) result = await service.model_eligibility(search_space_id=5) assert result == {"allowed": False, "violations": [{"kind": "image"}]} diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 574f6d9fd..8e0806151 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit def _search_space(*, llm: int | None, image: int | None, vision: int | None): """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" return SimpleNamespace( - chat_model_id=llm, - image_gen_model_id=image, - vision_model_id=vision, + agent_llm_id=llm, + image_generation_config_id=image, + vision_llm_config_id=vision, ) @@ -39,11 +39,29 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. 
""" + llm_configs = { + -1: {"id": -1, "billing_tier": "premium"}, + -2: {"id": -2, "billing_tier": "free"}, + } + monkeypatch.setattr( + "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", + lambda cid: llm_configs.get(cid), + ) + from app.config import config as app_config monkeypatch.setattr( app_config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", + [ + {"id": -1, "billing_tier": "premium"}, + {"id": -2, "billing_tier": "free"}, + ], + raising=False, + ) + monkeypatch.setattr( + app_config, + "GLOBAL_VISION_LLM_CONFIGS", [ {"id": -1, "billing_tier": "premium"}, {"id": -2, "billing_tier": "free"}, @@ -53,7 +71,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): return None -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: """A positive config id is a user-owned BYOK model — always billable.""" allowed, reason = model_policy._classify(kind, 7) @@ -61,7 +79,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) @pytest.mark.parametrize("config_id", [0, None]) def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: """Auto mode (id 0) and an unset slot (None) are blocked.""" @@ -70,7 +88,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: assert "Auto mode" in reason -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_premium_global_is_allowed(kind: str, patched_globals) -> None: """A negative (global) id with premium billing tier is allowed.""" allowed, reason = model_policy._classify(kind, -1) @@ -78,7 +96,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None: assert 
reason == "" -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_free_global_is_blocked(kind: str, patched_globals) -> None: """A negative (global) id with a free billing tier is blocked.""" allowed, reason = model_policy._classify(kind, -2) @@ -86,7 +104,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None: assert "free model" in reason -@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) +@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: """A negative id that resolves to no config is treated as not premium.""" allowed, _ = model_policy._classify(kind, -999) @@ -107,10 +125,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None: assert result["allowed"] is False kinds = {v["kind"] for v in result["violations"]} - assert kinds == {"chat", "image", "vision"} - # model_id is echoed back for the UI / settings deep-link. - by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} - assert by_kind == {"chat": -2, "image": 0, "vision": -2} + assert kinds == {"llm", "image", "vision"} + # config_id is echoed back for the UI / settings deep-link. 
+ by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} def test_assert_raises_with_violations(patched_globals) -> None: @@ -120,7 +138,7 @@ def test_assert_raises_with_violations(patched_globals) -> None: assert_automation_models_billable(search_space) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "chat" + assert exc_info.value.violations[0]["kind"] == "llm" def test_assert_passes_when_all_billable(patched_globals) -> None: @@ -135,7 +153,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None: def test_get_model_eligibility_all_billable(patched_globals) -> None: """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" result = get_model_eligibility( - chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1 + agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 ) assert result == {"allowed": True, "violations": []} @@ -143,28 +161,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None: def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" result = get_model_eligibility( - chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 ) assert result["allowed"] is False - by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} - assert by_kind == {"chat": -2, "image": 0, "vision": -2} + by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} + assert by_kind == {"llm": -2, "image": 0, "vision": -2} def test_assert_models_billable_raises(patched_globals) -> None: """``assert_models_billable`` raises when any explicit id is blocked.""" with pytest.raises(AutomationModelPolicyError) as exc_info: assert_models_billable( - chat_model_id=0, image_gen_model_id=5, vision_model_id=-1 + 
agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 ) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "chat" + assert exc_info.value.violations[0]["kind"] == "llm" def test_assert_models_billable_passes(patched_globals) -> None: """No exception when every explicit id is premium or BYOK.""" assert ( assert_models_billable( - chat_model_id=3, image_gen_model_id=-1, vision_model_id=4 + agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 ) is None ) @@ -174,5 +192,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: """The search-space wrapper produces the same result as the ID core.""" search_space = _search_space(llm=-2, image=0, vision=-2) assert get_automation_model_eligibility(search_space) == get_model_eligibility( - chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 + agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 ) diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py deleted file mode 100644 index c6efddc09..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Stub the cache package __init__s so unit tests import only pure leaf modules. - -The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly -import the facade, file storage, Celery, and ``app.db`` -- none of which a pure -unit test should need. Turning those packages into bare namespace packages lets -``from app.etl_pipeline.cache.. import ...`` resolve the leaf module -without running the heavy __init__. ``schemas`` is left real (it is pure). 
-""" - -import sys -import types -from pathlib import Path - -_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "etl_pipeline" / "cache" - - -def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: - if dotted in sys.modules: - return - module = types.ModuleType(dotted) - module.__path__ = [str(fs_dir)] - module.__package__ = dotted - sys.modules[dotted] = module - - -_stub_namespace_package("app.etl_pipeline.cache", _CACHE_DIR) -_stub_namespace_package("app.etl_pipeline.cache.storage", _CACHE_DIR / "storage") -_stub_namespace_package("app.etl_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py deleted file mode 100644 index 99d8e67b6..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py +++ /dev/null @@ -1,88 +0,0 @@ -"""What is allowed into the cache -- the gating rules, as pure logic. - -These rules decide whether a given upload may be served from / written to the -parse cache. They live in a pure predicate so every branch (disabled, vision, -no service, file category) is covered here without touching DB, storage, or the -parser. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.eligibility import is_parse_cacheable - -pytestmark = pytest.mark.unit - - -def test_document_with_service_and_cache_on_is_cacheable(): - assert is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) - - -def test_disabled_cache_is_never_cacheable(): - assert not is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=False, - has_vision_llm=False, - ) - - -def test_vision_llm_run_is_not_cacheable(): - # Vision appends model output not captured by the key; sharing it would leak - # one run's generated text into a plain parse of the same bytes. 
- assert not is_parse_cacheable( - filename="report.pdf", - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=True, - ) - - -@pytest.mark.parametrize("etl_service", [None, ""]) -def test_missing_etl_service_is_not_cacheable(etl_service): - assert not is_parse_cacheable( - filename="report.pdf", - etl_service=etl_service, - cache_enabled=True, - has_vision_llm=False, - ) - - -@pytest.mark.parametrize( - "filename", - ["paper.pdf", "memo.docx", "slides.pptx", "sheet.xlsx", "book.epub"], -) -def test_document_extensions_are_cacheable(filename): - assert is_parse_cacheable( - filename=filename, - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) - - -@pytest.mark.parametrize( - "filename", - [ - "notes.txt", # plaintext - "readme.md", # plaintext - "main.py", # plaintext - "podcast.mp3", # audio - "photo.png", # image (vision path / fallback, not a shared doc parse) - "data.csv", # direct-convert - "archive.xyz", # unsupported - ], -) -def test_non_document_categories_are_not_cacheable(filename): - assert not is_parse_cacheable( - filename=filename, - etl_service="LLAMACLOUD", - cache_enabled=True, - has_vision_llm=False, - ) diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py deleted file mode 100644 index 5113d7c42..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Size-based eviction: drop just enough of the coldest entries to fit budget. - -The caller supplies candidates already ordered coldest-first; this pure rule only -decides how far down that list to cut. It must never over-evict (stop as soon as -the footprint fits) and never promise more than the candidates can free. 
-""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest - -from app.etl_pipeline.cache.eviction.policy import select_over_budget -from app.etl_pipeline.cache.schemas import EvictionCandidate - -pytestmark = pytest.mark.unit - - -def _candidate(id_: int, size_bytes: int) -> EvictionCandidate: - return EvictionCandidate( - id=id_, - storage_key=f"etl_cache/{id_}.md", - size_bytes=size_bytes, - last_used_at=datetime(2026, 1, 1, tzinfo=UTC), - times_reused=0, - ) - - -def test_over_budget_drops_coldest_until_it_fits(): - # 300 used, budget 100 -> must free >=200. Coldest-first [120, 90, 70]; - # 120+90=210 >=200, so the third (70) is spared. - coldest_first = [_candidate(1, 120), _candidate(2, 90), _candidate(3, 70)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=300, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1, 2] - - -@pytest.mark.parametrize("current_total_bytes", [100, 80]) -def test_within_budget_evicts_nothing(current_total_bytes): - # At or under budget there is nothing to free, so no blob is touched. - coldest_first = [_candidate(1, 50), _candidate(2, 50)] - - chosen = select_over_budget( - coldest_first, - current_total_bytes=current_total_bytes, - max_total_bytes=100, - ) - - assert chosen == [] - - -def test_stops_as_soon_as_one_entry_covers_the_overage(): - # Only 10 over budget; the first (cold) entry already frees enough. - coldest_first = [_candidate(1, 40), _candidate(2, 40)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=110, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1] - - -def test_returns_all_candidates_when_they_cannot_free_enough(): - # Deficit is 500 but candidates only total 150: return everything available - # rather than looping forever or raising. 
- coldest_first = [_candidate(1, 100), _candidate(2, 50)] - - chosen = select_over_budget( - coldest_first, current_total_bytes=600, max_total_bytes=100 - ) - - assert [c.id for c in chosen] == [1, 2] diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py deleted file mode 100644 index d69e74ee0..000000000 --- a/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Content-addressing: equal (bytes + recipe) must map to one storage location. - -This is the dedup guarantee the whole cache rests on -- two users uploading the -same file under the same parser settings have to land on the same object key, and -any change to bytes or recipe has to land somewhere else. -""" - -from __future__ import annotations - -import pytest - -from app.etl_pipeline.cache.schemas import ParseKey -from app.etl_pipeline.cache.storage.object_keys import ( - CACHE_PREFIX, - build_parse_object_key, -) - -pytestmark = pytest.mark.unit - - -def _key(**overrides) -> ParseKey: - base = { - "source_sha256": "a" * 64, - "etl_service": "LLAMACLOUD", - "mode": "basic", - "version": 1, - } - base.update(overrides) - return ParseKey.for_document( - base["source_sha256"], - etl_service=base["etl_service"], - mode=base["mode"], - version=base["version"], - ) - - -def test_same_bytes_and_recipe_produce_the_same_object_key(): - assert build_parse_object_key(_key()) == build_parse_object_key(_key()) - - -def test_different_bytes_produce_different_object_keys(): - assert build_parse_object_key( - _key(source_sha256="a" * 64) - ) != build_parse_object_key(_key(source_sha256="b" * 64)) - - -@pytest.mark.parametrize( - "field, value", - [ - ("etl_service", "DOCLING"), - ("mode", "premium"), - ("version", 2), - ], -) -def test_any_recipe_change_produces_a_different_object_key(field, value): - # Same bytes but a different parser/mode/version must not collide: the recipe - # 
is part of the identity, so changing it has to re-parse, not reuse. - assert build_parse_object_key(_key()) != build_parse_object_key( - _key(**{field: value}) - ) - - -def test_object_key_is_prefixed_and_sharded_by_source_hash(): - # Shape matters operationally: a dedicated top-level prefix keeps cache blobs - # out of the normal store, and the sha directory groups every recipe variant - # of one file together. - key = _key() - assert build_parse_object_key(key) == ( - f"{CACHE_PREFIX}/{key.source_sha256}/LLAMACLOUD.basic.v1.md" - ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py deleted file mode 100644 index 081dddaa7..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Stub the cache package __init__s so unit tests import only pure leaf modules. - -The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly -import the facade, file storage, Celery, and ``app.db`` -- none of which a pure -unit test should need. Turning those packages into bare namespace packages lets -``from app.indexing_pipeline.cache. import ...`` resolve the leaf module -without running the heavy __init__. ``schemas`` is left real (it is pure). 
-""" - -import sys -import types -from pathlib import Path - -_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "indexing_pipeline" / "cache" - - -def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: - if dotted in sys.modules: - return - module = types.ModuleType(dotted) - module.__path__ = [str(fs_dir)] - module.__package__ = dotted - sys.modules[dotted] = module - - -_stub_namespace_package("app.indexing_pipeline.cache", _CACHE_DIR) -_stub_namespace_package("app.indexing_pipeline.cache.storage", _CACHE_DIR / "storage") -_stub_namespace_package("app.indexing_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py deleted file mode 100644 index 2e488231c..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py +++ /dev/null @@ -1,28 +0,0 @@ -from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable - - -def test_disabled_cache_is_never_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=False, embedding_model="m", embedding_dim=384 - ) - - -def test_missing_model_is_not_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model=None, embedding_dim=384 - ) - - -def test_missing_dimension_is_not_cacheable(): - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=None - ) - assert not is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=0 - ) - - -def test_enabled_with_model_and_dim_is_cacheable(): - assert is_embedding_cacheable( - cache_enabled=True, embedding_model="m", embedding_dim=384 - ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py deleted file mode 100644 index ce9c8672d..000000000 --- 
a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py +++ /dev/null @@ -1,31 +0,0 @@ -from app.indexing_pipeline.cache.schemas import EmbeddingKey - - -def _key(**overrides) -> EmbeddingKey: - base = { - "markdown_sha256": "a" * 64, - "embedding_model": "openai://text-embedding-3-small", - "embedding_dim": 1536, - "chunker_kind": "hybrid", - "chunker_version": 1, - } - base.update(overrides) - return EmbeddingKey(**base) - - -def test_object_suffix_is_stable(): - assert _key().object_suffix == _key().object_suffix - - -def test_object_suffix_differs_by_model(): - assert _key().object_suffix != _key(embedding_model="local/minilm").object_suffix - - -def test_object_suffix_differs_by_chunker_kind_and_version(): - assert _key().object_suffix != _key(chunker_kind="code").object_suffix - assert _key().object_suffix != _key(chunker_version=2).object_suffix - - -def test_object_suffix_encodes_kind_and_version(): - suffix = _key(chunker_kind="code", chunker_version=3).object_suffix - assert suffix.endswith(".code.v3.emb") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py deleted file mode 100644 index f8cff6355..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import pytest - -from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet -from app.indexing_pipeline.cache.serialization import deserialize, serialize - - -def _make_set(dim: int, n_chunks: int) -> EmbeddingSet: - rng = np.random.default_rng(0) - return EmbeddingSet( - summary_embedding=rng.random(dim, dtype=np.float64), - chunks=[ - CachedChunk(text=f"chunk {i}\nwith newline", embedding=rng.random(dim)) - for i in range(n_chunks) - ], - ) - - -def test_round_trip_preserves_texts_and_vectors(): - original = _make_set(dim=8, n_chunks=3) - - restored = 
deserialize(serialize(original)) - - assert [c.text for c in restored.chunks] == [c.text for c in original.chunks] - assert restored.chunk_count == 3 - assert np.allclose( - restored.summary_embedding, original.summary_embedding, atol=1e-6 - ) - for got, want in zip(restored.chunks, original.chunks, strict=True): - assert np.allclose(got.embedding, want.embedding, atol=1e-6) - - -def test_round_trip_with_no_chunks(): - original = _make_set(dim=4, n_chunks=0) - - restored = deserialize(serialize(original)) - - assert restored.chunk_count == 0 - assert restored.summary_embedding.shape[0] == 4 - - -def test_serialize_rejects_mismatched_dimensions(): - bad = EmbeddingSet( - summary_embedding=np.zeros(4, dtype=np.float32), - chunks=[CachedChunk(text="x", embedding=np.zeros(8, dtype=np.float32))], - ) - - with pytest.raises(ValueError): - serialize(bad) - - -def test_deserialize_rejects_foreign_blob(): - with pytest.raises(ValueError): - deserialize(b"not-a-surfsense-blob") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py b/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py deleted file mode 100644 index 7effce840..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py +++ /dev/null @@ -1,94 +0,0 @@ -"""reconcile(): diff existing chunk rows against new chunk texts. - -The reconciler decides which rows (and embeddings) survive an edit, which texts -must be embedded, and which rows go away -- purely from content, no DB. 
-""" - -from __future__ import annotations - -from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile - - -def _existing(*contents: str) -> list[ExistingChunk]: - return [ - ExistingChunk(id=i + 1, content=text, position=i) - for i, text in enumerate(contents) - ] - - -def test_identical_content_keeps_every_row_untouched(): - plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "beta", "gamma"]) - - assert plan.to_embed == [] - assert plan.to_delete == [] - assert plan.reused == [] - - -def test_head_insert_embeds_only_the_new_chunk_and_shifts_the_rest(): - plan = reconcile(_existing("alpha", "beta"), ["intro", "alpha", "beta"]) - - assert plan.to_embed == [(0, "intro")] - assert plan.to_delete == [] - # alpha: position 0 -> 1, beta: 1 -> 2; embeddings untouched. - assert plan.reused == [(1, 1), (2, 2)] - - -def test_middle_edit_swaps_exactly_one_chunk(): - plan = reconcile( - _existing("alpha", "beta", "gamma"), ["alpha", "beta EDITED", "gamma"] - ) - - assert plan.to_embed == [(1, "beta EDITED")] - assert plan.to_delete == [2] - # Neighbours did not move, so no position writes at all. - assert plan.reused == [] - - -def test_removed_chunk_is_deleted_and_followers_shift_up(): - plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "gamma"]) - - assert plan.to_embed == [] - assert plan.to_delete == [2] - assert plan.reused == [(3, 1)] - - -def test_duplicate_texts_pair_up_one_to_one(): - # Two identical boilerplate chunks, only one survives the edit: exactly one - # row is kept and exactly one is deleted -- never both kept or both dropped. 
- plan = reconcile(_existing("boiler", "boiler", "body"), ["boiler", "body"]) - - assert plan.to_embed == [] - assert plan.to_delete == [2] - assert plan.reused == [(3, 1)] - - -def test_duplicate_growth_embeds_only_the_extra_copy(): - plan = reconcile(_existing("boiler", "body"), ["boiler", "boiler", "body"]) - - assert plan.to_embed == [(1, "boiler")] - assert plan.to_delete == [] - assert plan.reused == [(2, 2)] - - -def test_reorder_becomes_position_updates_with_no_embedding(): - plan = reconcile(_existing("alpha", "beta"), ["beta", "alpha"]) - - assert plan.to_embed == [] - assert plan.to_delete == [] - assert sorted(plan.reused) == [(1, 1), (2, 0)] - - -def test_full_rewrite_replaces_everything(): - plan = reconcile(_existing("alpha", "beta"), ["new one", "new two"]) - - assert plan.to_embed == [(0, "new one"), (1, "new two")] - assert sorted(plan.to_delete) == [1, 2] - assert plan.reused == [] - - -def test_no_existing_chunks_embeds_all(): - plan = reconcile([], ["alpha", "beta"]) - - assert plan.to_embed == [(0, "alpha"), (1, "beta")] - assert plan.to_delete == [] - assert plan.reused == [] diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py index feb7bbc52..3a1b77d90 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -54,7 +54,7 @@ async def test_index_calls_embed_and_chunk_via_to_thread( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock_chunk_hybrid, ) mock_embed = MagicMock( @@ -62,21 +62,13 @@ async def test_index_calls_embed_and_chunk_via_to_thread( ) mock_embed.__name__ = "embed_texts" 
monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", mock_embed, ) + # Bypass set_committed_value, which requires a real ORM instance (not MagicMock). monkeypatch.setattr( - pipeline, - "_load_existing_chunks", - AsyncMock(return_value=[]), - ) - - async def _noop_persist(_session, doc, *_args, **_kwargs): - doc.status = DocumentStatus.ready() - - monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index", - _noop_persist, + "app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document", + MagicMock(), ) connector_doc = make_connector_document( @@ -110,31 +102,22 @@ async def test_non_code_documents_use_hybrid_chunker( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", mock_chunk_hybrid, ) mock_chunk_code = MagicMock(return_value=["chunk1"]) mock_chunk_code.__name__ = "chunk_text" monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.chunk_text", + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", mock_chunk_code, ) monkeypatch.setattr( - "app.indexing_pipeline.cache.cached_indexing.embed_texts", + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( - pipeline, - "_load_existing_chunks", - AsyncMock(return_value=[]), - ) - - async def _noop_persist(_session, doc, *_args, **_kwargs): - doc.status = DocumentStatus.ready() - - monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index", - _noop_persist, + "app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document", + MagicMock(), ) connector_doc = make_connector_document( diff --git 
a/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py b/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py deleted file mode 100644 index 026c3161d..000000000 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_persist_scratch_index.py +++ /dev/null @@ -1,65 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from app.db import Chunk, Document, DocumentStatus -from app.indexing_pipeline.document_persistence import persist_scratch_index - -pytestmark = pytest.mark.unit - - -def _make_document(doc_id: int = 1) -> Document: - document = MagicMock(spec=Document) - document.id = doc_id - document.content = None - document.status = DocumentStatus.processing() - return document - - -@pytest.mark.asyncio -async def test_persist_scratch_index_batches_commits(monkeypatch): - monkeypatch.setattr( - "app.indexing_pipeline.document_persistence.set_committed_value", - lambda *_args, **_kwargs: None, - ) - session = MagicMock() - session.commit = AsyncMock() - document = _make_document() - chunks = [Chunk(content=f"c{i}", embedding=[0.1], position=i) for i in range(5)] - perf = MagicMock() - - await persist_scratch_index( - session, - document, - "body", - chunks, - batch_size=2, - perf=perf, - ) - - assert session.commit.await_count == 5 - assert document.status == DocumentStatus.ready() - - -@pytest.mark.asyncio -async def test_persist_scratch_index_empty_chunks(monkeypatch): - monkeypatch.setattr( - "app.indexing_pipeline.document_persistence.set_committed_value", - lambda *_args, **_kwargs: None, - ) - session = MagicMock() - session.commit = AsyncMock() - document = _make_document() - perf = MagicMock() - - await persist_scratch_index( - session, - document, - "body", - [], - batch_size=200, - perf=perf, - ) - - assert session.commit.await_count == 2 - assert document.status == DocumentStatus.ready() diff --git 
a/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py index 9fc93d3ed..2f0a6a9d3 100644 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_document_processing.py @@ -61,21 +61,3 @@ def test_completion_failure(): assert message == "Processing failed: bad" assert status == "failed" assert meta["processing_stage"] == "failed" - - -def test_started_title_truncates_long_name(): - """Very long filenames are truncated to fit the notification title column.""" - long_name = "a" * 250 - title = msg.started_title(long_name) - assert len(title) <= 200 - assert title.startswith("Processing: ") - assert title.endswith("...") - - -def test_completion_truncates_long_name(): - """Completion titles truncate long document names.""" - long_name = "b" * 250 - title, _, _, _ = msg.completion(long_name, document_id=1) - assert len(title) <= 200 - assert title.startswith("Ready: ") - assert title.endswith("...") diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_text.py b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py index 183779a9c..bf3611607 100644 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_text.py +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_text.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from app.notifications.service.messages.text import format_title, truncate +from app.notifications.service.messages.text import truncate pytestmark = pytest.mark.unit @@ -22,22 +22,3 @@ def test_truncate_keeps_text_at_exact_limit(): def test_truncate_appends_ellipsis_when_over_limit(): """Text past the limit is cut to the limit and gains an ellipsis.""" assert truncate("a" * 41, 40) == "a" * 40 + "..." 
- - -def test_format_title_keeps_short_name(): - """Short names are joined to the prefix without truncation.""" - assert format_title("Ready: ", "report.pdf") == "Ready: report.pdf" - - -def test_format_title_truncates_long_name(): - """Long names are truncated so the full title fits the DB limit.""" - long_name = "a" * 250 - title = format_title("Processing: ", long_name) - assert len(title) == 200 - assert title.startswith("Processing: ") - assert title.endswith("...") - - -def test_format_title_respects_custom_max_length(): - """A custom max length caps the title.""" - assert len(format_title("Go: ", "hello world", max_length=10)) == 10 diff --git a/surfsense_backend/tests/unit/podcasts/conftest.py b/surfsense_backend/tests/unit/podcasts/conftest.py index c77eb1cc6..5eb4d8457 100644 --- a/surfsense_backend/tests/unit/podcasts/conftest.py +++ b/surfsense_backend/tests/unit/podcasts/conftest.py @@ -31,8 +31,8 @@ def make_spec(): language: str = "en", style: PodcastStyle = PodcastStyle.CONVERSATIONAL, speakers: list[SpeakerSpec] | None = None, - min_seconds: int = 600, - max_seconds: int = 1200, + min_minutes: int = 10, + max_minutes: int = 20, focus: str | None = None, ) -> PodcastSpec: if speakers is None: @@ -54,7 +54,7 @@ def make_spec(): language=language, style=style, speakers=speakers, - duration=DurationTarget(min_seconds=min_seconds, max_seconds=max_seconds), + duration=DurationTarget(min_minutes=min_minutes, max_minutes=max_minutes), focus=focus, ) diff --git a/surfsense_backend/tests/unit/podcasts/test_renderer.py b/surfsense_backend/tests/unit/podcasts/test_renderer.py index bb7b8f181..2bcdff967 100644 --- a/surfsense_backend/tests/unit/podcasts/test_renderer.py +++ b/surfsense_backend/tests/unit/podcasts/test_renderer.py @@ -66,7 +66,7 @@ def _spec(voice_id: str) -> PodcastSpec: speakers=[ SpeakerSpec(slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_id) ], - duration=DurationTarget(min_seconds=300, max_seconds=600), + 
duration=DurationTarget(min_minutes=5, max_minutes=10), ) diff --git a/surfsense_backend/tests/unit/podcasts/test_spec.py b/surfsense_backend/tests/unit/podcasts/test_spec.py index 77e720286..4efd530e9 100644 --- a/surfsense_backend/tests/unit/podcasts/test_spec.py +++ b/surfsense_backend/tests/unit/podcasts/test_spec.py @@ -57,7 +57,7 @@ def test_spec_normalizes_its_language_on_construction(): spec = PodcastSpec( language="EN-us", speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) assert spec.language == "en-us" @@ -68,7 +68,7 @@ def test_speakers_must_have_unique_slots(): PodcastSpec( language="en", speakers=[_speaker(0), _speaker(0, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) @@ -77,7 +77,7 @@ def test_a_brief_needs_at_least_one_speaker(): PodcastSpec( language="en", speakers=[], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) @@ -86,7 +86,7 @@ def test_a_monologue_brief_carries_exactly_one_speaker(): language="en", style=PodcastStyle.MONOLOGUE, speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) assert spec.style is PodcastStyle.MONOLOGUE @@ -98,25 +98,18 @@ def test_a_monologue_brief_rejects_multiple_speakers(): language="en", style=PodcastStyle.MONOLOGUE, speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) def test_duration_rejects_an_inverted_range(): """A max below the min is a user error caught at the brief gate.""" with pytest.raises(ValidationError): - DurationTarget(min_seconds=1200, max_seconds=600) + DurationTarget(min_minutes=20, max_minutes=10) def 
test_duration_midpoint_is_where_drafting_aims(): - assert DurationTarget(min_seconds=600, max_seconds=1200).midpoint_seconds == 900 - assert DurationTarget(min_seconds=600, max_seconds=1200).midpoint_minutes == 15 - - -def test_duration_loads_legacy_minute_fields_from_json(): - duration = DurationTarget.model_validate({"min_minutes": 10, "max_minutes": 20}) - assert duration.min_seconds == 600 - assert duration.max_seconds == 1200 + assert DurationTarget(min_minutes=10, max_minutes=20).midpoint_minutes == 15 def test_blank_focus_becomes_absent(): @@ -124,7 +117,7 @@ def test_blank_focus_becomes_absent(): spec = PodcastSpec( language="en", speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), focus=" ", ) assert spec.focus is None @@ -134,7 +127,7 @@ def test_speaker_for_returns_the_speaker_bound_to_a_slot(): spec = PodcastSpec( language="en", speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) assert spec.speaker_for(1).voice_id == "kokoro:af_bella" @@ -143,7 +136,7 @@ def test_speaker_for_raises_when_no_speaker_matches(): spec = PodcastSpec( language="en", speakers=[_speaker(0)], - duration=DurationTarget(min_seconds=300, max_seconds=600), + duration=DurationTarget(min_minutes=5, max_minutes=10), ) with pytest.raises(KeyError): spec.speaker_for(99) diff --git a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py index d120d4bfc..861d8768c 100644 --- a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py +++ b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py @@ -75,59 +75,6 @@ def test_supports_language_reports_availability(): assert not catalog.supports_language(TtsProvider.KOKORO, "de") -def test_offerable_languages_for_a_concrete_roster_are_its_tags_only(): 
- """A provider whose voices are language-bound offers exactly those tags.""" - catalog = VoiceCatalog( - [ - _voice("k1", language="en-US"), - _voice("k2", language="fr"), - _voice("k3", language="fr"), - ] - ) - - offering = catalog.offerable_languages(TtsProvider.KOKORO) - - assert offering.languages == ["en-US", "fr"] - assert offering.allows_custom is False - - -def test_a_wildcard_roster_offers_the_curated_languages_and_custom_entry(): - """Voices that speak anything can't enumerate languages themselves, so the - catalog offers the curated common list and invites free entry.""" - catalog = VoiceCatalog( - [_voice("o1", provider=TtsProvider.OPENAI, language=ANY_LANGUAGE)] - ) - - offering = catalog.offerable_languages(TtsProvider.OPENAI) - - assert {"en", "fr", "sw", "hi", "zh"} <= set(offering.languages) - assert offering.allows_custom is True - - -def test_a_mixed_roster_offers_the_union_of_concrete_and_curated(): - catalog = VoiceCatalog( - [ - _voice("v1", provider=TtsProvider.VERTEX_AI, language="en-GB"), - _voice("v2", provider=TtsProvider.VERTEX_AI, language=ANY_LANGUAGE), - ] - ) - - offering = catalog.offerable_languages(TtsProvider.VERTEX_AI) - - assert "en-GB" in offering.languages - assert "fr" in offering.languages - assert offering.allows_custom is True - - -def test_a_provider_with_no_voices_offers_nothing(): - catalog = VoiceCatalog([_voice("k1")]) - - offering = catalog.offerable_languages(TtsProvider.OPENAI) - - assert offering.languages == [] - assert offering.allows_custom is False - - def test_get_raises_for_an_unknown_voice(): catalog = VoiceCatalog([_voice("k1")]) with pytest.raises(KeyError): diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for 
``supports_image_input`` derivation on BYOK chat config +endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. 
The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. 
This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). 
These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. 
+ """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. 
+ """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 
index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + 
"provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. 
+ assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 3d94c6c51..636b7de31 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -27,18 +27,9 @@ async def test_resolve_billing_for_auto_mode(monkeypatch): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - async def _no_auto_candidates(*_args, **_kwargs): - return [] - - monkeypatch.setattr( - image_generation_routes, - "auto_model_candidates", - _no_auto_candidates, - ) - - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( - session=None, + session=None, # Not consumed on this code path. 
config_id=0, # IMAGE_GEN_AUTO_MODE_ID search_space=search_space, ) @@ -54,48 +45,26 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", [ { "id": -1, - "connection_id": -101, - "model_id": "gpt-image-1", + "provider": "OPENAI", + "model_name": "gpt-image-1", "billing_tier": "premium", - "catalog": {"quota_reserve_micros": 75_000}, + "quota_reserve_micros": 75_000, }, { "id": -2, - "connection_id": -102, - "model_id": "google/gemini-2.5-flash-image", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", "billing_tier": "free", - "catalog": {}, - }, - ], - raising=False, - ) - monkeypatch.setattr( - config, - "GLOBAL_CONNECTIONS", - [ - { - "id": -101, - "provider": "openai", - "api_key": "sk-test", - "base_url": None, - "extra": {}, - }, - { - "id": -102, - "provider": "openrouter", - "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, }, ], raising=False, ) - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) # Premium with override. 
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( @@ -125,7 +94,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) + search_space = SimpleNamespace(image_generation_config_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( session=None, config_id=42, search_space=search_space ) @@ -136,7 +105,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): @pytest.mark.asyncio async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): - """When the request omits ``image_gen_model_id``, the helper + """When the request omits ``image_generation_config_id``, the helper must consult the search space's default — so a search space pinned to a premium global config still gates new requests by quota. 
""" @@ -145,34 +114,19 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_MODELS", + "GLOBAL_IMAGE_GEN_CONFIGS", [ { "id": -7, - "connection_id": -101, - "model_id": "gpt-image-1", + "provider": "OPENAI", + "model_name": "gpt-image-1", "billing_tier": "premium", - "catalog": {}, - } - ], - raising=False, - ) - monkeypatch.setattr( - config, - "GLOBAL_CONNECTIONS", - [ - { - "id": -101, - "provider": "openai", - "api_key": "sk-test", - "base_url": None, - "extra": {}, } ], raising=False, ) - search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7) + search_space = SimpleNamespace(image_generation_config_id=-7) ( tier, model, diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py index b43540ba7..fa8819b39 100644 --- a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -1,4 +1,27 @@ -"""Unit tests for ``_resolve_agent_billing_for_search_space``.""" +"""Unit tests for ``_resolve_agent_billing_for_search_space``. + +Validates the resolver used by Celery podcast/video tasks to compute +``(owner_user_id, billing_tier, base_model)`` from a search space and its +agent LLM config. The resolver mirrors chat's billing-resolution pattern at +``stream_new_chat.py:2294-2351`` and is the single integration point that +prevents Auto-mode podcast/video from leaking premium credit. + +Coverage: + +* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium + global → returns ``("premium", )``. +* Auto mode + ``thread_id`` set, pin resolves to a negative-id free + global → returns ``("free", )``. +* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config + → always ``"free"``. +* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without + hitting the pin service. 
+* Negative id (no Auto) → uses ``get_global_llm_config``'s + ``billing_tier``. +* Positive id (user BYOK) → always ``"free"``. +* Search space not found → raises ``ValueError``. +* ``agent_llm_id`` is None → raises ``ValueError``. +""" from __future__ import annotations @@ -11,6 +34,11 @@ import pytest pytestmark = pytest.mark.unit +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + class _FakeExecResult: def __init__(self, obj): self._obj = obj @@ -23,6 +51,14 @@ class _FakeExecResult: class _FakeSession: + """Tiny AsyncSession stub. + + ``responses`` is a list of objects to return from successive + ``execute()`` calls (in order). The resolver makes at most two + ``execute()`` calls (search-space lookup, then optionally NewLLMConfig + lookup), so two queued responses cover the matrix. + """ + def __init__(self, responses: list): self._responses = list(responses) @@ -31,6 +67,9 @@ class _FakeSession: return _FakeExecResult(None) return _FakeExecResult(self._responses.pop(0)) + async def commit(self) -> None: + pass + @dataclass class _FakePinResolution: @@ -39,33 +78,53 @@ class _FakePinResolution: from_existing_pin: bool = False -def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace: - return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id) +def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace( + id=42, + agent_llm_id=agent_llm_id, + user_id=user_id, + ) -def _make_byok_model( - *, id_: int, base_model: str | None = None, model_id: str = "gpt-byok" +def _make_byok_config( + *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" ) -> SimpleNamespace: return SimpleNamespace( id=id_, - model_id=model_id, - catalog={"base_model": base_model} if base_model else {}, - connection=SimpleNamespace(enabled=True, 
search_space_id=42, user_id=None), + model_name=model_name, + litellm_params={"base_model": base_model} if base_model else {}, ) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): + """Auto + thread → pin service resolves to negative-id premium config → + resolver returns ``("premium", )``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - async def _fake_resolve_pin(*_args, **kwargs): - assert kwargs["selected_llm_config_id"] == 0 - assert kwargs["thread_id"] == 99 + # Mock the pin service to return a concrete premium config id. + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + assert selected_llm_config_id == 0 + assert thread_id == 99 return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") + # Mock global config lookup to return a premium entry. def _fake_get_global(cfg_id): if cfg_id == -1: return { @@ -76,6 +135,8 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): } return None + # Lazy imports inside the resolver — patch the *target* modules so the + # imported names resolve to our fakes. 
import app.services.auto_model_pin_service as pin_module import app.services.llm_service as llm_module @@ -94,17 +155,76 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): @pytest.mark.asyncio -async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): +async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): + """Auto + thread → pin returns negative-id free config → resolver + returns ``("free", )``. Same path the pin service takes for + out-of-credit users (graceful degradation).""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(chat_model_id=0, user_id=user_id) - byok_model = _make_byok_model( - id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude" - ) - session = _FakeSession([search_space, byok_model]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - async def _fake_resolve_pin(*_args, **_kwargs): + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): + return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") + + def _fake_get_global(cfg_id): + if cfg_id == -3: + return { + "id": -3, + "model_name": "openrouter/free-model", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/free-model"}, + } + return None + + import app.services.auto_model_pin_service as pin_module + import app.services.llm_service as llm_module + + monkeypatch.setattr( + pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin + ) + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=99 + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/free-model" + + 
+@pytest.mark.asyncio +async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): + """Auto + thread → pin returns positive-id BYOK config → resolver + returns ``("free", ...)`` (BYOK is always free per + ``AgentConfig.from_new_llm_config``).""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + search_space = _make_search_space(agent_llm_id=0, user_id=user_id) + byok_cfg = _make_byok_config( + id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + ) + session = _FakeSession([search_space, byok_cfg]) + + async def _fake_resolve_pin( + sess, + *, + thread_id, + search_space_id, + user_id, + selected_llm_config_id, + force_repin_free=False, + ): return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") import app.services.auto_model_pin_service as pin_module @@ -124,10 +244,13 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): @pytest.mark.asyncio async def test_auto_mode_without_thread_id_falls_back_to_free(): + """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking + the pin service. 
Forward-compat fallback for any future direct-API + entrypoint that doesn't have a chat thread.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42, thread_id=None @@ -140,10 +263,13 @@ async def test_auto_mode_without_thread_id_falls_back_to_free(): @pytest.mark.asyncio async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): + """If the pin service raises ``ValueError`` (thread missing / + mismatched search space), the resolver should log and return free + rather than killing the whole task.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) async def _fake_resolve_pin(*args, **kwargs): raise ValueError("thread missing") @@ -165,10 +291,12 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): @pytest.mark.asyncio async def test_negative_id_premium_global_returns_premium(monkeypatch): + """Explicit negative agent_llm_id → ``get_global_llm_config`` → + return its ``billing_tier``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) def _fake_get_global(cfg_id): return { @@ -192,14 +320,49 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch): @pytest.mark.asyncio -async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): +async def 
test_negative_id_free_global_returns_free(monkeypatch): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) def _fake_get_global(cfg_id): - return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"} + return { + "id": cfg_id, + "model_name": "openrouter/some-free", + "billing_tier": "free", + "litellm_params": {"base_model": "openrouter/some-free"}, + } + + import app.services.llm_service as llm_module + + monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) + + owner, tier, base_model = await _resolve_agent_billing_for_search_space( + session, search_space_id=42, thread_id=None + ) + + assert owner == user_id + assert tier == "free" + assert base_model == "openrouter/some-free" + + +@pytest.mark.asyncio +async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): + """When the global config has no ``litellm_params.base_model``, the + resolver falls back to ``model_name`` — matching chat's behavior.""" + from app.services.billable_calls import _resolve_agent_billing_for_search_space + + user_id = uuid4() + session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + + def _fake_get_global(cfg_id): + return { + "id": cfg_id, + "model_name": "fallback-model", + "billing_tier": "premium", + # No litellm_params. 
+ } import app.services.llm_service as llm_module @@ -215,12 +378,14 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat @pytest.mark.asyncio async def test_positive_id_byok_is_always_free(): + """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, + regardless of underlying provider tier.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(chat_model_id=23, user_id=user_id) - byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet") - session = _FakeSession([search_space, byok_model]) + search_space = _make_search_space(agent_llm_id=23, user_id=user_id) + byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_cfg]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -233,10 +398,13 @@ async def test_positive_id_byok_is_always_free(): @pytest.mark.asyncio async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): + """If the BYOK config row is missing/deleted but the search space still + points at it, the resolver still returns free (no debit) with an empty + base_model — billable_call's premium path is skipped, no harm done.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -251,18 +419,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): async def test_search_space_not_found_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space + session = _FakeSession([None]) + with 
pytest.raises(ValueError, match="Search space"): - await _resolve_agent_billing_for_search_space( - _FakeSession([None]), search_space_id=999 - ) + await _resolve_agent_billing_for_search_space(session, search_space_id=999) @pytest.mark.asyncio -async def test_chat_model_id_none_raises_value_error(): +async def test_agent_llm_id_none_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)]) + session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) - with pytest.raises(ValueError, match="chat_model_id"): + with pytest.raises(ValueError, match="agent_llm_id"): await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 598e9b1ab..5c5c90283 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -17,39 +17,8 @@ from app.services.auto_model_pin_service import ( pytestmark = pytest.mark.unit -class _FakeRedis: - def __init__(self): - self.values: dict[str, str] = {} - self.ttls: dict[str, int] = {} - - def set(self, key: str, value: str, *, ex: int | None = None): - self.values[key] = value - if ex is not None: - self.ttls[key] = ex - return True - - def mget(self, keys: list[str]): - return [self.values.get(key) for key in keys] - - def delete(self, *keys: str): - removed = 0 - for key in keys: - if key in self.values: - removed += 1 - self.values.pop(key, None) - self.ttls.pop(key, None) - return removed - - def scan_iter(self, pattern: str): - prefix = pattern.removesuffix("*") - return (key for key in list(self.values) if key.startswith(prefix)) - - @pytest.fixture(autouse=True) -def _clear_runtime_cooldown_map(monkeypatch): 
- import app.services.auto_model_pin_service as svc - - monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis()) +def _clear_runtime_cooldown_map(): clear_runtime_cooldown() clear_healthy() yield @@ -63,9 +32,8 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, *, thread=None, scalars=None): + def __init__(self, thread): self._thread = thread - self._scalars = scalars or [] def unique(self): return self @@ -73,71 +41,19 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread - def scalars(self): - return SimpleNamespace(all=lambda: self._scalars) - class _FakeSession: - def __init__(self, thread, *, models=None): + def __init__(self, thread): self.thread = thread - self.models = models or [] self.commit_count = 0 - self.execute_count = 0 async def execute(self, _stmt): - self.execute_count += 1 - if self.execute_count == 1: - return _FakeExecResult(thread=self.thread) - return _FakeExecResult(scalars=self.models) + return _FakeExecResult(self.thread) async def commit(self): self.commit_count += 1 -def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): - """Patch the new global model catalog shape from compact legacy cfg fixtures.""" - connections = [] - models = [] - for cfg in configs: - config_id = int(cfg["id"]) - connection_id = config_id - 100_000 - provider = cfg.get("provider") or cfg.get("litellm_provider") - model_name = cfg["model_name"] - connections.append( - { - "id": connection_id, - "provider": provider, - "scope": "GLOBAL", - "enabled": True, - } - ) - models.append( - { - "id": config_id, - "connection_id": connection_id, - "model_id": model_name, - "display_name": cfg.get("name") or model_name, - "supports_chat": cfg.get("supports_chat", True), - "supports_image_input": cfg.get("supports_image_input", True), - "supports_tools": cfg.get("supports_tools", True), - "supports_image_generation": cfg.get( - "supports_image_generation", False - ), - "capabilities_override": 
cfg.get("capabilities_override") or {}, - "billing_tier": cfg.get("billing_tier", "free"), - "catalog": { - "auto_pin_tier": cfg.get("auto_pin_tier"), - "quality_score": cfg.get("quality_score") - or cfg.get("quality_score_static"), - }, - } - ) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) - monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) - monkeypatch.setattr(config, "GLOBAL_MODELS", models) - - def _thread( *, search_space_id: int = 10, @@ -155,19 +71,14 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -200,13 +111,13 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", @@ -214,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -243,19 +154,17 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): @pytest.mark.asyncio -async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model( - monkeypatch, -): +async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): from app.config import config 
session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.1", "api_key": "k1", "billing_tier": "premium", @@ -264,7 +173,7 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.4", "api_key": "k2", "billing_tier": "premium", @@ -273,39 +182,12 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode }, { "id": -3, - "litellm_provider": "anthropic", - "model_name": "claude-opus", + "provider": "OPENROUTER", + "model_name": "openai/gpt-5.4", "api_key": "k3", "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 99, - }, - { - "id": -4, - "litellm_provider": "openai", - "model_name": "gpt-5.3", - "api_key": "k4", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 98, - }, - { - "id": -5, - "litellm_provider": "gemini", - "model_name": "gemini-3-pro", - "api_key": "k5", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 97, - }, - { - "id": -6, - "litellm_provider": "xai", - "model_name": "grok-5", - "api_key": "k6", - "billing_tier": "premium", - "auto_pin_tier": "A", - "quality_score": 96, + "auto_pin_tier": "B", + "quality_score": 100, }, ], ) @@ -325,7 +207,7 @@ async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_mode user_id="00000000-0000-0000-0000-000000000001", selected_llm_config_id=0, ) - assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6} + assert result.resolved_llm_config_id == -2 assert result.resolved_tier == "premium" @@ -334,13 +216,13 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( 
config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -375,13 +257,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -413,20 +295,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -458,20 +340,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -503,20 +385,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -2, - "litellm_provider": "openai", + 
"provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -551,16 +433,11 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-2)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -581,16 +458,11 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-999)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ - { - "id": -2, - "litellm_provider": "openai", - "model_name": "gpt-free", - "api_key": "k1", - }, + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -615,7 +487,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): # --------------------------------------------------------------------------- -# Quality-aware pin selection (Auto upgrade) +# Quality-aware pin selection (Auto Fastest upgrade) # --------------------------------------------------------------------------- @@ -626,13 +498,13 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "venice/dead-model", "api_key": "k1", "billing_tier": "free", @@ -642,7 +514,7 @@ async def 
test_health_gated_config_is_excluded_from_selection(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-flash", "api_key": "k1", "billing_tier": "free", @@ -678,13 +550,13 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -694,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "openai/gpt-5", "api_key": "k-or", "billing_tier": "premium", @@ -730,13 +602,13 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -746,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-flash:free", "api_key": "k-or", "billing_tier": "free", @@ -784,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): high_score_cfgs = [ { "id": -i, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": f"gpt-x-{i}", "api_key": "k", "billing_tier": "premium", @@ -796,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): ] low_score_trap = { "id": -99, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "tiny-legacy", 
"api_key": "k", "billing_tier": "premium", @@ -804,9 +676,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): "quality_score": 10, "health_gated": False, } - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [*high_score_cfgs, low_score_trap], ) @@ -851,13 +723,13 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "venice/dead-model", "api_key": "k", "billing_tier": "premium", @@ -867,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -903,13 +775,13 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -919,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): }, { "id": -2, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5-pro", "api_key": "k", "billing_tier": "premium", @@ -961,13 +833,13 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", 
"model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -977,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -1009,86 +881,18 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): assert result.from_existing_pin is False -def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch): - import app.services.auto_model_pin_service as svc - - mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123) - - redis_client = svc._runtime_cooldown_redis - assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited" - assert redis_client.ttls["auto:cooldown:llm:-9"] == 123 - - -@pytest.mark.asyncio -async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch): - """A Redis cooldown written by another worker should invalidate local pins.""" - import app.services.auto_model_pin_service as svc - from app.config import config - - session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, - config, - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "google/gemma-4-26b-a4b-it:free", - "api_key": "k", - "billing_tier": "free", - "auto_pin_tier": "C", - "quality_score": 90, - "health_gated": False, - }, - { - "id": -2, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash:free", - "api_key": "k", - "billing_tier": "free", - "auto_pin_tier": "C", - "quality_score": 80, - "health_gated": False, - }, - ], - ) - svc._runtime_cooldown_redis.set( - "auto:cooldown:llm:-1", - "provider_rate_limited", - ex=600, - ) - - async def _blocked(*_args, **_kwargs): - return _FakeQuotaResult(allowed=False) - - monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", - _blocked, - 
) - - result = await resolve_or_get_pinned_llm_config_id( - session, - thread_id=1, - search_space_id=10, - user_id="00000000-0000-0000-0000-000000000001", - selected_llm_config_id=0, - ) - assert result.resolved_llm_config_id == -2 - assert result.from_existing_pin is False - - @pytest.mark.asyncio async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -1127,13 +931,13 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [ { "id": -1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -1143,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa }, { "id": -2, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index c1e66feb9..3ca5c7a67 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -45,9 +45,8 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, *, thread=None, scalars=None): + def __init__(self, thread): self._thread = thread - self._scalars = scalars or [] def unique(self): 
return self @@ -55,21 +54,14 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread - def scalars(self): - return SimpleNamespace(all=lambda: self._scalars) - class _FakeSession: def __init__(self, thread): self.thread = thread self.commit_count = 0 - self.execute_count = 0 async def execute(self, _stmt): - self.execute_count += 1 - if self.execute_count == 1: - return _FakeExecResult(thread=self.thread) - return _FakeExecResult(scalars=[]) + return _FakeExecResult(self.thread) async def commit(self): self.commit_count += 1 @@ -79,64 +71,10 @@ def _thread(*, pinned: int | None = None): return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) -def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): - from app.services.provider_capabilities import derive_supports_image_input - - connections = [] - models = [] - for cfg in configs: - config_id = int(cfg["id"]) - connection_id = config_id - 100_000 - provider = cfg.get("provider") or cfg.get("litellm_provider") - model_name = cfg["model_name"] - if "supports_image_input" not in cfg: - litellm_params = cfg.get("litellm_params") or {} - base_model = ( - litellm_params.get("base_model") - if isinstance(litellm_params, dict) - else None - ) - cfg["supports_image_input"] = derive_supports_image_input( - provider=provider, - model_name=model_name, - base_model=base_model, - custom_provider=cfg.get("custom_provider"), - ) - connections.append( - { - "id": connection_id, - "provider": provider, - "scope": "GLOBAL", - "enabled": True, - } - ) - model = { - "id": config_id, - "connection_id": connection_id, - "model_id": model_name, - "display_name": cfg.get("name") or model_name, - "supports_chat": cfg.get("supports_chat", True), - "supports_tools": cfg.get("supports_tools", True), - "supports_image_generation": cfg.get("supports_image_generation", False), - "capabilities_override": cfg.get("capabilities_override") or {}, - "billing_tier": cfg.get("billing_tier", "free"), 
- "catalog": { - "auto_pin_tier": cfg.get("auto_pin_tier"), - "quality_score": cfg.get("quality_score"), - }, - "supports_image_input": cfg["supports_image_input"], - } - models.append(model) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) - monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) - monkeypatch.setattr(config, "GLOBAL_MODELS", models) - - def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: return { "id": id_, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": f"vision-{id_}", "api_key": "k", "billing_tier": tier, @@ -149,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: return { "id": id_, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": f"text-{id_}", "api_key": "k", "billing_tier": tier, @@ -170,7 +108,11 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)]) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", _premium_allowed, @@ -198,7 +140,11 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned=-1)) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1), _vision_cfg(-2)]) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", _premium_allowed, @@ -226,9 +172,9 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch): from app.config import config session = 
_FakeSession(_thread(pinned=-2)) - _set_global_llm_configs( - monkeypatch, + monkeypatch.setattr( config, + "GLOBAL_LLM_CONFIGS", [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], ) monkeypatch.setattr( @@ -257,8 +203,10 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs( - monkeypatch, config, [_text_only_cfg(-1), _text_only_cfg(-2)] + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", @@ -283,7 +231,11 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): from app.config import config session = _FakeSession(_thread()) - _set_global_llm_configs(monkeypatch, config, [_text_only_cfg(-1)]) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", _premium_allowed, @@ -309,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): session = _FakeSession(_thread()) cfg_unannotated_vision = { "id": -2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-4o", # known vision model in LiteLLM map "api_key": "k", "billing_tier": "free", @@ -317,7 +269,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): "quality_score": 80, # NOTE: no supports_image_input key } - _set_global_llm_configs(monkeypatch, config, [cfg_unannotated_vision]) + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", _premium_allowed, diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 
adcfeed48..571e7d15b 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -1,4 +1,19 @@ -"""Image-gen call sites must pass each config's explicit ``api_base``.""" +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. +""" from __future__ import annotations @@ -11,23 +26,22 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio -async def test_global_openrouter_image_gen_sets_explicit_api_base(): - """The global-config branch forwards the explicit OpenRouter base.""" +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. 
+ """ from app.routes import image_generation_routes - global_model = { + cfg = { "id": -20_001, - "connection_id": -101, - "model_id": "openai/gpt-image-1", - "supports_image_generation": True, - "capabilities_override": {}, - } - global_connection = { - "id": -101, - "provider": "openrouter", + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, } captured: dict = {} @@ -37,7 +51,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) image_gen = MagicMock() - image_gen.image_gen_model_id = global_model["id"] + image_gen.image_generation_config_id = cfg["id"] image_gen.prompt = "test" image_gen.n = 1 image_gen.quality = None @@ -47,19 +61,14 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): image_gen.model = None search_space = MagicMock() - search_space.image_gen_model_id = global_model["id"] + search_space.image_generation_config_id = cfg["id"] session = MagicMock() with ( patch.object( image_generation_routes, - "_get_global_model", - return_value=global_model, - ), - patch.object( - image_generation_routes, - "_get_global_connection", - return_value=global_connection, + "_get_global_image_gen_config", + return_value=cfg, ), patch.object( image_generation_routes, @@ -71,31 +80,30 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): session=session, image_gen=image_gen, search_space=search_space ) + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. 
assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1" @pytest.mark.asyncio -async def test_generate_image_tool_global_sets_explicit_api_base(): - """Same explicit-base behavior at the agent tool entry point — both surfaces share +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share the same OpenRouter config payloads.""" from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( generate_image as gi_module, ) - global_model = { + cfg = { "id": -20_001, - "connection_id": -101, - "model_id": "openai/gpt-image-1", - "supports_image_generation": True, - "capabilities_override": {}, - } - global_connection = { - "id": -101, - "provider": "openrouter", + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "base_url": "https://openrouter.ai/api/v1", - "extra": {}, + "api_base": "", + "api_version": None, + "litellm_params": {}, } captured: dict = {} @@ -111,7 +119,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): search_space = MagicMock() search_space.id = 1 - search_space.image_gen_model_id = global_model["id"] + search_space.image_generation_config_id = cfg["id"] session_cm = AsyncMock() session = AsyncMock() @@ -134,10 +142,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): with ( patch.object(gi_module, "shielded_async_session", return_value=session_cm), - patch.object(gi_module, "_get_global_model", return_value=global_model), - patch.object( - gi_module, "_get_global_connection", return_value=global_connection - ), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), patch.object( gi_module, "aimage_generation", side_effect=fake_aimage_generation ), @@ -166,16 +171,20 @@ async def 
test_generate_image_tool_global_sets_explicit_api_base(): assert captured["model"] == "openrouter/openai/gpt-image-1" -def test_image_gen_router_deployment_sets_explicit_api_base(): - """The Auto-mode router pool carries explicit api_base into deployments.""" +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ from app.services.image_gen_router_service import ImageGenRouterService deployment = ImageGenRouterService._config_to_deployment( { "model_name": "openai/gpt-image-1", - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", } ) assert deployment is not None diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index 349a7d445..c309ff881 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -25,10 +25,10 @@ def _fake_yaml_config( return { "id": id, "name": f"yaml-{id}", - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": model_name, "api_key": "sk-test", - "api_base": "https://api.openai.com/v1", + "api_base": "", "billing_tier": billing_tier, "rpm": 100, "tpm": 100_000, @@ -54,10 +54,10 @@ def _fake_openrouter_config( return { "id": id, "name": f"or-{id}", - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_name, "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", + "api_base": "", "billing_tier": billing_tier, "rpm": 20 if billing_tier == "free" else 200, "tpm": 100_000 if billing_tier == "free" else 1_000_000, @@ 
-217,64 +217,10 @@ def test_auto_model_pin_candidates_include_dynamic_openrouter(): model_name="meta-llama/llama-3.3-70b:free", billing_tier="free", ) - global_connections = [ - { - "id": -110_001, - "provider": "openrouter", - "scope": "GLOBAL", - "enabled": True, - }, - { - "id": -110_002, - "provider": "openrouter", - "scope": "GLOBAL", - "enabled": True, - }, - ] - global_models = [ - { - "id": or_premium["id"], - "connection_id": -110_001, - "model_id": or_premium["model_name"], - "display_name": or_premium["name"], - "supports_chat": True, - "supports_image_input": True, - "supports_tools": True, - "supports_image_generation": False, - "capabilities_override": {}, - "billing_tier": or_premium["billing_tier"], - "catalog": { - "auto_pin_tier": "A", - "quality_score": 50, - }, - }, - { - "id": or_free["id"], - "connection_id": -110_002, - "model_id": or_free["model_name"], - "display_name": or_free["name"], - "supports_chat": True, - "supports_image_input": True, - "supports_tools": True, - "supports_image_generation": False, - "capabilities_override": {}, - "billing_tier": or_free["billing_tier"], - "catalog": { - "auto_pin_tier": "A", - "quality_score": 50, - }, - }, - ] - original_configs = config.GLOBAL_LLM_CONFIGS - original_connections = config.GLOBAL_CONNECTIONS - original_models = config.GLOBAL_MODELS + original = config.GLOBAL_LLM_CONFIGS try: config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] - config.GLOBAL_CONNECTIONS = global_connections - config.GLOBAL_MODELS = global_models candidate_ids = {c["id"] for c in _global_candidates()} assert candidate_ids == {-10_001, -10_002} finally: - config.GLOBAL_LLM_CONFIGS = original_configs - config.GLOBAL_CONNECTIONS = original_connections - config.GLOBAL_MODELS = original_models + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py deleted file mode 100644 index 937eda806..000000000 
--- a/surfsense_backend/tests/unit/services/test_model_connections.py +++ /dev/null @@ -1,78 +0,0 @@ -from app.services.global_model_catalog import materialize_global_model_catalog -from app.services.model_resolver import ensure_v1, to_litellm - - -def test_openai_compatible_resolver_uses_explicit_api_base() -> None: - model, kwargs = to_litellm( - { - "protocol": "OPENAI_COMPATIBLE", - "provider": "openai", - "base_url": "http://host.docker.internal:1234/v1", - "api_key": "local-key", - "extra": {}, - }, - "qwen/qwen3", - ) - - assert model == "openai/qwen/qwen3" - assert kwargs["api_base"] == "http://host.docker.internal:1234/v1" - assert kwargs["api_key"] == "local-key" - assert ensure_v1("http://example.com/v1") == "http://example.com/v1" - - -def test_ollama_resolver_uses_native_api_base() -> None: - model, kwargs = to_litellm( - { - "protocol": "OLLAMA", - "provider": "ollama_chat", - "base_url": "http://host.docker.internal:11434", - "api_key": None, - "extra": {}, - }, - "llama3.2", - ) - - assert model == "ollama_chat/llama3.2" - assert kwargs["api_base"] == "http://host.docker.internal:11434" - - -def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None: - connections, models = materialize_global_model_catalog( - chat_configs=[ - { - "id": -101, - "name": "OpenRouter Free", - "litellm_provider": "openrouter", - "model_name": "meta-llama/llama-3.1-8b-instruct:free", - "api_key": "sk-global-secret", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "free", - "anonymous_enabled": True, - "seo_enabled": True, - "rpm": 10, - "tpm": 1000, - }, - { - "id": -102, - "name": "OpenRouter Premium", - "litellm_provider": "openrouter", - "model_name": "anthropic/claude-sonnet-4", - "api_key": "sk-global-secret", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, - ], - image_configs=[], - ) - - assert len(connections) == 1 - assert connections[0]["api_key"] == "sk-global-secret" - assert 
{model["billing_tier"] for model in models} == {"free", "premium"} - assert models[0]["catalog"]["anonymous_enabled"] is True - assert models[0]["catalog"]["rpm"] == 10 - - public_connections = [ - {key: value for key, value in connection.items() if key != "api_key"} - for connection in connections - ] - assert "sk-" not in repr(public_connections) diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 147d62592..88fcf2db3 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): # --------------------------------------------------------------------------- -# _generate_image_gen_configs +# _generate_image_gen_configs / _generate_vision_llm_configs # --------------------------------------------------------------------------- @@ -263,15 +263,18 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["provider"] == "openrouter" + assert c["provider"] == "OPENROUTER" assert c[_OPENROUTER_DYNAMIC_MARKER] is True - # Emit the OpenRouter base URL at source so every call path passes an - # explicit api_base and cannot inherit a process-global endpoint. + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. 
assert c["api_base"] == "https://openrouter.ai/api/v1" def test_generate_image_gen_configs_assigns_image_id_offset(): - """Image configs use their own id_offset (-20000).""" + """Image configs use a different id_offset (-20000) so their negative + IDs don't collide with chat configs (-10000) or vision configs (-30000). + """ from app.services.openrouter_integration_service import ( _generate_image_gen_configs, ) @@ -288,3 +291,90 @@ def test_generate_image_gen_configs_assigns_image_id_offset(): cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE)) assert all(c["id"] < -20_000 + 1 for c in cfgs) assert all(c["id"] > -29_000_000 for c in cfgs) + + +def test_generate_vision_llm_configs_filters_by_image_input_text_output(): + """Vision LLMs must accept image input AND emit text — pure image-gen + (no text out) and text-only (no image in) models are excluded. + """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + # GPT-4o: vision LLM (image in, text out) — must emit. + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "context_length": 128_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + # Pure image generator — image *output*, no text out. Must NOT emit. + { + "id": "openai/gpt-image-1", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["image"], + }, + "context_length": 4_000, + "pricing": {"prompt": "0", "completion": "0"}, + }, + # Pure text model (no image in). Must NOT emit. 
+ { + "id": "anthropic/claude-3-haiku", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "context_length": 200_000, + "pricing": {"prompt": "0.000001", "completion": "0.000005"}, + }, + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + names = {c["model_name"] for c in cfgs} + assert names == {"openai/gpt-4o"} + + cfg = cfgs[0] + assert cfg["billing_tier"] == "premium" + # Pricing carried inline so pricing_registration can register vision + # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache + # is cleared. + assert cfg["input_cost_per_token"] == pytest.approx(5e-6) + assert cfg["output_cost_per_token"] == pytest.approx(15e-6) + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" + + +def test_generate_vision_llm_configs_drops_chat_only_filters(): + """A small-context vision model that doesn't advertise tool calling is + still a valid vision LLM for "describe this image" prompts. The chat + filters (``supports_tool_calling``, ``has_sufficient_context``) must + NOT be applied to vision emission. 
+ """ + from app.services.openrouter_integration_service import ( + _generate_vision_llm_configs, + ) + + raw = [ + { + "id": "tiny/vision-mini", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": [], # no tools + "context_length": 4_000, # well below MIN_CONTEXT_LENGTH + "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, + } + ] + + cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) + assert len(cfgs) == 1 + assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py index 00706f43c..1c74aa928 100644 --- a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -25,7 +25,7 @@ def _or_cfg( ) -> dict: return { "id": cid, - "provider": "openrouter", + "provider": "OPENROUTER", "model_name": model_name, "billing_tier": tier, "auto_pin_tier": "B" if tier == "premium" else "C", @@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch): """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" yaml_cfg = { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "billing_tier": "premium", "auto_pin_tier": "A", @@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" yaml_cfg: dict[str, Any] = { "id": -1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index e1437ca24..e97250ff2 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ 
b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): [ { "id": 1, - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5.4", "litellm_params": { "base_model": "gpt-5.4", @@ -243,6 +243,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): keys = spy.all_keys assert "gpt-5.4" in keys + assert "azure_openai/gpt-5.4" in keys assert "azure/gpt-5.4" in keys payload = spy.calls[0] @@ -270,7 +271,7 @@ def test_no_override_means_no_registration(monkeypatch): [ { "id": 1, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "gpt-4o", "litellm_params": {"base_model": "gpt-4o"}, } @@ -301,7 +302,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -348,12 +349,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): [ { "id": 1, - "litellm_provider": "openrouter", + "provider": "OPENROUTER", "model_name": "anthropic/claude-3-5-sonnet", }, { "id": 2, - "litellm_provider": "openai", + "provider": "OPENAI", "model_name": "custom-deployment", "litellm_params": { "base_model": "custom-deployment", @@ -368,3 +369,79 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): # The good config still registered. assert any("custom-deployment" in payload for payload in successful_calls) + + +def test_vision_configs_registered_with_chat_shape(monkeypatch): + """``register_pricing_from_global_configs`` walks + ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision + calls (during indexing) bill correctly. 
Vision configs use the same + chat-shape token prices, but image-gen pricing is intentionally NOT + registered here (handled via ``response_cost`` in LiteLLM). + """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing( + monkeypatch, + {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, + ) + + # No chat configs — only vision. Proves the vision walk is a separate + # iteration, not piggy-backed on the chat list. + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "billing_tier": "premium", + "input_cost_per_token": 5e-6, + "output_cost_per_token": 15e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/openai/gpt-4o" in spy.all_keys + payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] + assert payload_value["mode"] == "chat" + assert payload_value["litellm_provider"] == "openrouter" + assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) + assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) + + +def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): + """If the OpenRouter pricing cache misses a vision model (different + catalogue surface), the vision walk falls back to inline + ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. 
+ """ + from app.config import config + from app.services.pricing_registration import register_pricing_from_global_configs + + spy = _patch_register(monkeypatch) + _patch_openrouter_pricing(monkeypatch, {}) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) + monkeypatch.setattr( + config, + "GLOBAL_VISION_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash", + "billing_tier": "premium", + "input_cost_per_token": 1e-6, + "output_cost_per_token": 4e-6, + } + ], + ) + + register_pricing_from_global_configs() + + assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). 
+""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + 
provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py index d20af14ae..aac88977f 100644 --- a/surfsense_backend/tests/unit/services/test_provider_capabilities.py +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit def test_or_modalities_with_image_returns_true(): assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="openai/gpt-4o", openrouter_input_modalities=["text", "image"], ) @@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true(): def test_or_modalities_text_only_returns_false(): assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="deepseek/deepseek-v3.2-exp", openrouter_input_modalities=["text"], ) @@ -57,7 +57,7 @@ 
def test_or_modalities_empty_list_returns_false(): to LiteLLM.""" assert ( derive_supports_image_input( - provider="openrouter", + provider="OPENROUTER", model_name="weird/empty-modalities", openrouter_input_modalities=[], ) @@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm(): to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" assert ( derive_supports_image_input( - provider="openai", + provider="OPENAI", model_name="gpt-4o", openrouter_input_modalities=None, ) @@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm(): def test_litellm_known_vision_model_returns_true(): assert ( derive_supports_image_input( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is True @@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name(): doesn't know) would shadow the real capability.""" assert ( derive_supports_image_input( - provider="azure", + provider="AZURE_OPENAI", model_name="my-azure-deployment-id", base_model="gpt-4o", ) @@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows(): """Default-allow on unknown — the safety net is the actual block.""" assert ( derive_supports_image_input( - provider="custom", + provider="CUSTOM", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false(): # Sanity: confirm the helper's negative path. We use a small model # known not to support vision per the map. 
result = derive_supports_image_input( - provider="openai", + provider="DEEPSEEK", model_name="deepseek-chat", ) # We accept either False (LiteLLM said explicit no) or True @@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false(): def test_is_known_text_only_returns_false_for_vision_model(): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model(): fixing.""" assert ( is_known_text_only_chat_model( - provider="custom", + provider="CUSTOM", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is True @@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is False @@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="any-model", ) is False diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index cb3f7523a..6fbc8fd62 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -1,4 +1,4 @@ -"""Unit tests for the Auto quality scoring module.""" +"""Unit tests for the Auto (Fastest) quality scoring module.""" from __future__ import annotations @@ -228,7 
+228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider(): def test_static_score_yaml_includes_operator_bonus(): cfg = { - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } @@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus(): def test_static_score_yaml_unknown_provider_still_carries_bonus(): cfg = { - "litellm_provider": "some_new_provider", + "provider": "SOME_NEW_PROVIDER", "model_name": "weird-model", } score = static_score_yaml(cfg) @@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus(): def test_static_score_yaml_clamped_0_to_100(): cfg = { - "litellm_provider": "azure", + "provider": "AZURE_OPENAI", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py index 0f5dd531f..9e35b6f9c 100644 --- a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -105,7 +105,8 @@ async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): async def _denying_billable_call(**_kwargs): raise QuotaInsufficientError( usage_type="vision_extraction", - balance_micros=5_000_000, + used_micros=5_000_000, + limit_micros=5_000_000, remaining_micros=0, ) yield # unreachable but required for asynccontextmanager type diff --git a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py index 9eeb55a4d..63681828d 100644 --- a/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py +++ b/surfsense_backend/tests/unit/services/test_token_quota_service_cost.py @@ -112,77 +112,6 @@ def test_per_message_summary_groups_cost_by_model(): assert 
summary["gpt-4o-mini"]["cost_micros"] == 200 -def test_add_reconciles_metadata_when_litellm_strips_provider_prefix(): - """Regression: LiteLLM's ``get_llm_provider`` strips the provider prefix we - add in ``to_litellm`` (``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because - ``azure`` is in ``litellm.provider_list``), so the success callback reports - the bare model. Metadata registered under the *prefixed* string must still - attach to the call so the per-model breakdown carries provider/display_name - — otherwise the UI falls back to a bare-name collision and mis-attributes an - Azure turn to an OpenRouter model (e.g. shows "OpenAI: GPT-5.2 Chat"). - """ - from app.services.token_tracking_service import TurnTokenAccumulator - - acc = TurnTokenAccumulator() - acc.register_model_metadata( - model="azure/gpt-5.2-chat", - model_ref="global:-1", - model_id="gpt-5.2-chat", - display_name="Azure GPT 5.2", - provider="azure", - ) - # LiteLLM callback fires with the prefix-stripped model name. - acc.add( - model="gpt-5.2-chat", - prompt_tokens=100, - completion_tokens=50, - total_tokens=150, - cost_micros=4_000, - ) - - summary = acc.per_message_summary() - entry = summary["gpt-5.2-chat"] - assert entry["provider"] == "azure" - assert entry["display_name"] == "Azure GPT 5.2" - assert entry["model_id"] == "gpt-5.2-chat" - assert entry["model_ref"] == "global:-1" - - -def test_add_prefers_exact_metadata_over_bare_alias(): - """When the callback model matches a registered key exactly, the exact - metadata wins even if another model shares the same bare name — so a turn - that legitimately used two same-named deployments stays correctly - attributed.""" - from app.services.token_tracking_service import TurnTokenAccumulator - - acc = TurnTokenAccumulator() - acc.register_model_metadata( - model="azure/gpt-5.2-chat", - model_ref="global:-1", - model_id="gpt-5.2-chat", - display_name="Azure GPT 5.2", - provider="azure", - ) - acc.register_model_metadata( - 
model="openai/gpt-5.2-chat", - model_ref="db:7", - model_id="gpt-5.2-chat", - display_name="OpenAI GPT 5.2", - provider="openai", - ) - acc.add( - model="openai/gpt-5.2-chat", - prompt_tokens=10, - completion_tokens=5, - total_tokens=15, - cost_micros=100, - ) - - entry = acc.per_message_summary()["openai/gpt-5.2-chat"] - assert entry["provider"] == "openai" - assert entry["display_name"] == "OpenAI GPT 5.2" - - def test_serialized_calls_includes_cost_micros(): """``serialized_calls`` is what flows into the SSE ``call_details`` payload; cost_micros must be present on each entry so the FE message-info @@ -202,10 +131,6 @@ def test_serialized_calls_includes_cost_micros(): assert serialized == [ { "model": "m", - "model_ref": None, - "model_id": None, - "display_name": None, - "provider": None, "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2, diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..5e3aa6eda --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. 
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": 
"openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py deleted file mode 100644 index 98ffaa282..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import pytest - -from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception -from app.tasks.chat.streaming.errors.classifier import classify_stream_exception - -pytestmark = pytest.mark.unit - - -def _exception_named(name: str, message: str) -> Exception: - return type(name, (Exception,), {})(message) - - -def test_adapter_classifies_authentication_error_by_class_name() -> None: - exc = _exception_named("AuthenticationError", "provider rejected credentials") - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.AUTH_FAILED - assert adapted.retryable is False - assert adapted.user_message == "LLM authentication failed. Check your API key." 
- - -def test_adapter_classifies_embedded_provider_401_payload() -> None: - exc = RuntimeError( - 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' - ) - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.AUTH_FAILED - assert adapted.provider_status_code == 401 - - -def test_adapter_preserves_rate_limit_classification() -> None: - exc = RuntimeError('{"error":{"message":"Slow down","code":429}}') - - adapted = adapt_llm_exception(exc) - - assert adapted.category is LLMErrorCategory.RATE_LIMITED - assert adapted.retryable is True - - -def test_stream_classifier_maps_model_auth_to_stable_code() -> None: - exc = RuntimeError( - 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' - ) - - kind, code, severity, expected, message, extra = classify_stream_exception( - exc, - flow_label="chat", - ) - - assert kind == "model_auth_failed" - assert code == "MODEL_AUTH_FAILED" - assert severity == "warn" - assert expected is True - assert "API key" in message - assert extra == { - "provider_error_category": "auth_failed", - "provider_status_code": 401, - } - - -def test_stream_classifier_keeps_unknown_errors_generic() -> None: - exc = RuntimeError("database exploded") - - kind, code, severity, expected, message, extra = classify_stream_exception( - exc, - flow_label="chat", - ) - - assert kind == "server_error" - assert code == "SERVER_ERROR" - assert severity == "error" - assert expected is False - assert message == "Error during chat: database exploded" - assert extra is None diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py deleted file mode 100644 index cecf8be5d..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Contracts for chat LLM construction in streaming 
flows. - -``stream_new_chat`` / ``stream_resume_chat`` depend on LangChain receiving -token chunks from ``ChatLiteLLM``. ``langchain-litellm`` defaults -``streaming`` to ``False``, so the shared bundle loader must opt in -explicitly for both DB-backed and global model paths. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any - -import pytest - -import app.tasks.chat.streaming.flows.shared.llm_bundle as llm_bundle - -pytestmark = pytest.mark.unit - - -class _CapturedChatLiteLLM: - calls: list[dict[str, Any]] = [] - - def __init__(self, **kwargs: Any) -> None: - self.kwargs = kwargs - self.__class__.calls.append(kwargs) - - -@pytest.fixture(autouse=True) -def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch): - """Keep these tests focused on the LLM constructor contract.""" - - _CapturedChatLiteLLM.calls = [] - - async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace: - return SimpleNamespace(id=42, user_id="user-1") - - monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space) - monkeypatch.setattr(llm_bundle, "SanitizedChatLiteLLM", _CapturedChatLiteLLM) - monkeypatch.setattr(llm_bundle, "register_model_usage_metadata", lambda **_kw: None) - monkeypatch.setattr( - llm_bundle, - "has_capability", - lambda _model, capability: capability in {"chat", "vision"}, - ) - - return None - - -async def test_load_llm_bundle_enables_streaming_for_db_models( - monkeypatch: pytest.MonkeyPatch, -) -> None: - connection = SimpleNamespace( - provider="openai", - api_key="sk-test", - base_url=None, - extra={"litellm_params": {"temperature": 0.1}}, - ) - model = SimpleNamespace( - id=7, - model_id="gpt-4o-mini", - display_name="GPT 4o Mini", - connection=connection, - ) - - async def _fake_db_model(_session: Any, *, model_id: int, search_space: Any) -> Any: - assert model_id == 7 - assert search_space.id == 42 - return model - - monkeypatch.setattr(llm_bundle, 
"_load_db_model", _fake_db_model) - monkeypatch.setattr( - llm_bundle, - "to_litellm", - lambda _conn, _model_id: ( - "openai/gpt-4o-mini", - {"api_key": "sk-test", "temperature": 0.1}, - ), - ) - - llm, agent_config, error = await llm_bundle.load_llm_bundle( - object(), - config_id=7, - search_space_id=42, - ) - - assert error is None - assert llm is not None - assert agent_config is not None - assert _CapturedChatLiteLLM.calls == [ - { - "model": "openai/gpt-4o-mini", - "api_key": "sk-test", - "temperature": 0.1, - "streaming": True, - } - ] - - -async def test_load_llm_bundle_enables_streaming_for_global_models( - monkeypatch: pytest.MonkeyPatch, -) -> None: - global_model = { - "id": -11, - "connection_id": -101, - "model_id": "claude-sonnet-4-5", - "display_name": "Claude Sonnet", - "billing_tier": "premium", - } - global_connection = { - "id": -101, - "provider": "anthropic", - "api_key": "sk-ant-test", - "base_url": None, - "extra": {"litellm_params": {"temperature": 0.2}}, - } - monkeypatch.setattr( - llm_bundle.config, - "GLOBAL_MODELS", - [global_model], - raising=False, - ) - monkeypatch.setattr( - llm_bundle.config, - "GLOBAL_CONNECTIONS", - [global_connection], - raising=False, - ) - monkeypatch.setattr( - llm_bundle, - "to_litellm", - lambda _conn, _model_id: ( - "anthropic/claude-sonnet-4-5", - {"api_key": "sk-ant-test", "temperature": 0.2}, - ), - ) - - llm, agent_config, error = await llm_bundle.load_llm_bundle( - object(), - config_id=-11, - search_space_id=42, - ) - - assert error is None - assert llm is not None - assert agent_config is not None - assert _CapturedChatLiteLLM.calls == [ - { - "model": "anthropic/claude-sonnet-4-5", - "api_key": "sk-ant-test", - "temperature": 0.2, - "streaming": True, - } - ] diff --git a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py deleted file mode 100644 index 5b2e4fdca..000000000 --- 
a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Unit tests for provider-safe LLM history normalization.""" - -from __future__ import annotations - -import pytest - -from app.tasks.chat.llm_history_normalizer import ( - assistant_content_to_llm_text, - user_content_to_llm_content, -) - -pytestmark = pytest.mark.unit - - -def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None: - content = [ - {"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]}, - {"type": "text", "text": "visible answer"}, - ] - - assert assistant_content_to_llm_text(content) == "visible answer" - - -def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None: - content = [ - {"type": "thinking", "thinking": "private reasoning"}, - {"type": "text", "text": "final answer"}, - ] - - assert assistant_content_to_llm_text(content) == "final answer" - - -def test_unknown_assistant_blocks_are_dropped() -> None: - content = [ - {"type": "redacted_thinking", "data": "hidden"}, - {"type": "tool_use", "name": "search"}, - {"type": "text", "text": "kept"}, - ] - - assert assistant_content_to_llm_text(content) == "kept" - - -def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None: - content = [ - {"type": "text", "text": "look"}, - {"type": "image", "image": "data:image/png;base64,abc"}, - ] - - assert user_content_to_llm_content(content, allow_images=True) == [ - {"type": "text", "text": "look"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, - ] - - -def test_user_images_can_be_dropped_for_text_only_history() -> None: - content = [ - {"type": "text", "text": "look"}, - {"type": "image", "image": "data:image/png;base64,abc"}, - ] - - assert user_content_to_llm_content(content, allow_images=False) == "look" diff --git a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py 
deleted file mode 100644 index 91ca01d95..000000000 --- a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Unit tests for final assistant message part normalization.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage - -from app.tasks.chat.message_parts_normalizer import ( - final_assistant_parts_from_messages, - merge_streamed_and_final_parts, - normalize_ai_message_to_parts, -) - -pytestmark = pytest.mark.unit - - -def test_string_ai_message_content_becomes_text_part() -> None: - assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [ - {"type": "text", "text": "hello"} - ] - - -def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None: - message = AIMessage( - content=[ - {"type": "thinking", "thinking": "hidden reasoning"}, - {"type": "text", "text": "Yo bro! What's up?"}, - ], - additional_kwargs={"reasoning_content": "hidden reasoning"}, - ) - - assert normalize_ai_message_to_parts(message) == [ - {"type": "text", "text": "Yo bro! 
What's up?"} - ] - - -def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None: - messages = [ - HumanMessage(content="ask"), - AIMessage(content="draft"), - ToolMessage(content="tool output", tool_call_id="tc-1"), - AIMessage(content=[{"type": "text", "text": "final answer"}]), - ToolMessage(content="trailing tool noise", tool_call_id="tc-2"), - ] - - assert final_assistant_parts_from_messages(messages) == [ - {"type": "text", "text": "final answer"} - ] - - -def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None: - streamed = [ - { - "type": "data-thinking-steps", - "data": [{"id": "thinking-1", "status": "completed"}], - } - ] - final = [{"type": "text", "text": "visible answer"}] - - assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final] - - -def test_merge_does_not_duplicate_when_stream_already_has_text() -> None: - streamed = [{"type": "text", "text": "streamed answer"}] - final = [{"type": "text", "text": "final answer"}] - - assert merge_streamed_and_final_parts(streamed, final) == streamed diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py index e82957acd..792d059b0 100644 --- a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o(): it text-only.""" assert ( is_known_text_only_chat_model( - provider="azure", + provider="AZURE_OPENAI", model_name="my-azure-deployment", base_model="gpt-4o", ) @@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model(): LiteLLM doesn't know about must flow through to the provider.""" assert ( is_known_text_only_chat_model( - provider="custom", + provider="CUSTOM", custom_provider="brand_new_proxy", model_name="brand-new-model-x9", ) @@ -69,7 +69,7 @@ def 
test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="gpt-4o", ) is False @@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="text-only-stub", ) is True @@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="vision-stub", ) is False @@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) assert ( is_known_text_only_chat_model( - provider="openai", + provider="OPENAI", model_name="missing-key-stub", ) is False diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py index 7183024ed..423b64ddb 100644 --- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -98,7 +98,8 @@ async def _denying_billable_call(**kwargs): _CALL_LOG.append(kwargs) raise QuotaInsufficientError( usage_type=kwargs.get("usage_type", "?"), - balance_micros=5_000_000, + used_micros=5_000_000, + limit_micros=5_000_000, remaining_micros=0, ) yield SimpleNamespace() # pragma: no cover diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 8c540b41c..182b9679f 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -18,9 +18,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", 
"python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -37,9 +34,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -56,9 +50,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -75,9 +66,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -94,9 +82,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 
'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -113,9 +98,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -132,9 +114,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -151,9 +130,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -170,9 +146,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -189,9 +162,6 @@ resolution-markers = [ 
"python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -213,10 +183,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -233,9 +199,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -252,9 +215,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -271,9 +231,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' 
and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -290,9 +247,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -309,9 +263,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -328,9 +279,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -347,9 +295,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < 
'0'", "python_version < '0'", @@ -366,9 +311,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -385,9 +327,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -404,9 +343,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -428,10 +364,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -448,9 +380,6 @@ resolution-markers = [ "python_version < '0'", "python_version < 
'0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -467,9 +396,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -486,9 +412,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -505,9 +428,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -524,9 +444,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' 
and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -543,9 +460,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -562,9 +476,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -581,9 +492,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -600,9 +508,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 
'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -619,9 +524,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -643,10 +545,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -663,9 +561,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -697,12 +592,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ 
-719,9 +608,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -738,9 +624,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -772,12 +655,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -809,12 +686,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -831,9 +702,6 @@ 
resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", ] conflicts = [[ @@ -3331,6 +3199,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661 }, ] +[[package]] +name = "flower" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "celery" }, + { name = "humanize" }, + { name = "prometheus-client" }, + { name = "pytz" }, + { name = "tornado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/a1/357f1b5d8946deafdcfdd604f51baae9de10aafa2908d0b7322597155f92/flower-2.0.1.tar.gz", hash = "sha256:5ab717b979530770c16afb48b50d2a98d23c3e9fe39851dcf6bc4d01845a02a0", size = 3220408 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/ff/ee2f67c0ff146ec98b5df1df637b2bc2d17beeb05df9f427a67bd7a7d79c/flower-2.0.1-py2.py3-none-any.whl", hash = "sha256:9db2c621eeefbc844c8dd88be64aef61e84e2deb29b271e02ab2b5b9f01068e2", size = 383553 }, +] + [[package]] name = "flupy" version = "1.2.3" @@ -4029,6 +3913,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/af/48ac8483240de756d2438c380746e7130d1c6f75802ef22f3c6d49982787/huggingface_hub-0.36.2-py3-none-any.whl", hash = "sha256:48f0c8eac16145dfce371e9d2d7772854a4f591bcb56c9cf548accf531d54270", size = 566395 }, ] +[[package]] +name = "humanize" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/ba/66/a3921783d54be8a6870ac4ccffcd15c4dc0dd7fcce51c6d63b8c63935276/humanize-4.15.0.tar.gz", hash = "sha256:1dd098483eb1c7ee8e32eb2e99ad1910baefa4b75c3aff3a82f4d78688993b10", size = 83599 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl", hash = "sha256:b1186eb9f5a9749cd9cb8565aee77919dd7c8d076161cf44d70e59e3301e1769", size = 132203 }, +] + [[package]] name = "hyperframe" version = "6.1.0" @@ -7546,6 +7439,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/59/123aee44a039b212cfb8d90be1adf06496a99b313ee1683aadf90b3d9799/progress-1.6.1-py3-none-any.whl", hash = "sha256:5239f22f305c12fdc8ce6e0e47f70f21622a935e16eafc4535617112e7c7ea0b", size = 9761 }, ] +[[package]] +name = "prometheus-client" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/58/a794d23feb6b00fc0c72787d7e87d872a6730dd9ed7c7b3e954637d8f280/prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9", size = 85616 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057 }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -9712,7 +9614,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.29" +version = "0.0.28" source = { editable = "." 
} dependencies = [ { name = "alembic" }, @@ -9737,6 +9639,7 @@ dependencies = [ { name = "fastapi-users", extra = ["oauth", "sqlalchemy"] }, { name = "faster-whisper" }, { name = "firecrawl-py" }, + { name = "flower" }, { name = "fractional-indexing" }, { name = "github3-py" }, { name = "gitingest" }, @@ -9855,6 +9758,7 @@ requires-dist = [ { name = "fastapi-users", extras = ["oauth", "sqlalchemy"], specifier = ">=15.0.3" }, { name = "faster-whisper", specifier = ">=1.1.0" }, { name = "firecrawl-py", specifier = ">=4.9.0" }, + { name = "flower", specifier = ">=2.0.1" }, { name = "fractional-indexing", specifier = ">=0.1.3" }, { name = "github3-py", specifier = "==4.0.1" }, { name = "gitingest", specifier = ">=0.3.1" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index 4d888acdb..e7f0f082c 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.29", + "version": "0.0.28", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index 5b663de00..f4cc9586d 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,7 +1,7 @@ { "name": "surfsense-desktop", "productName": "SurfSense", - "version": "0.0.29", + "version": "0.0.28", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_desktop/scripts/build-electron.mjs b/surfsense_desktop/scripts/build-electron.mjs index cc2083fe4..75a3cdf61 100644 --- a/surfsense_desktop/scripts/build-electron.mjs +++ b/surfsense_desktop/scripts/build-electron.mjs @@ -108,11 +108,8 @@ async function buildElectron() { sourcemap: true, minify: false, define: { - 'process.env.HOSTED_BACKEND_URL': 
JSON.stringify( - process.env.HOSTED_BACKEND_URL || desktopEnv.HOSTED_BACKEND_URL || '' - ), 'process.env.HOSTED_FRONTEND_URL': JSON.stringify( - process.env.HOSTED_FRONTEND_URL || desktopEnv.HOSTED_FRONTEND_URL || 'https://surfsense.com' + process.env.HOSTED_FRONTEND_URL || desktopEnv.HOSTED_FRONTEND_URL || 'https://surfsense.net' ), 'process.env.POSTHOG_KEY': JSON.stringify( process.env.POSTHOG_KEY || desktopEnv.POSTHOG_KEY || '' diff --git a/surfsense_desktop/src/modules/server.ts b/surfsense_desktop/src/modules/server.ts index d7274ad9c..fc2fa05c3 100644 --- a/surfsense_desktop/src/modules/server.ts +++ b/surfsense_desktop/src/modules/server.ts @@ -43,13 +43,11 @@ export async function startNextServer(): Promise { const standalonePath = getStandalonePath(); const serverScript = path.join(standalonePath, 'server.js'); - const backendInternalUrl = process.env.SURFSENSE_BACKEND_INTERNAL_URL || process.env.HOSTED_BACKEND_URL; const child = utilityProcess.fork(serverScript, [], { cwd: standalonePath, env: { ...process.env, - ...(backendInternalUrl ? { SURFSENSE_BACKEND_INTERNAL_URL: backendInternalUrl } : {}), PORT: String(serverPort), // Loopback bind: avoids 0.0.0.0 leaking into request.url and redirect origins. HOSTNAME: SERVER_HOST, diff --git a/surfsense_evals/README.md b/surfsense_evals/README.md index e6fc52ca1..c755c4de6 100644 --- a/surfsense_evals/README.md +++ b/surfsense_evals/README.md @@ -77,7 +77,7 @@ The walkthrough above is `--scenario head-to-head` (default): both arms answer w | `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? 
| | `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?| -In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_model_id` (auto-picked from the strongest registered global OpenRouter vision-capable model — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. +In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. ### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`) @@ -118,7 +118,7 @@ python -m surfsense_evals report --suite medical Notes: - `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model `. -- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/model-connections/global` and auto-picks the strongest registered vision-capable model. Pass `--no-vision-llm-setup` if you want to keep whatever vision model is already attached to the SearchSpace. +- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace. - The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration. 
- All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set. diff --git a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json index b6c59e2bc..a4687f64a 100644 --- a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json +++ b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json @@ -9,7 +9,7 @@ "llamacloud_premium_lc", "surfsense_agentic" ], - "chat_model_id": -5138454, + "agent_llm_id": -5138454, "concurrency": 2, "llm_model": "anthropic/claude-sonnet-4.5", "n_pdfs": 30, diff --git a/surfsense_evals/src/surfsense_evals/core/cli.py b/surfsense_evals/src/surfsense_evals/core/cli.py index 17979fba0..3d4d0fd24 100644 --- a/surfsense_evals/src/surfsense_evals/core/cli.py +++ b/surfsense_evals/src/surfsense_evals/core/cli.py @@ -2,7 +2,7 @@ Subcommands: -* ``setup --suite --provider-model [--chat-model-id ]`` +* ``setup --suite --provider-model [--agent-llm-id ]`` * ``teardown --suite `` * ``models list [--provider openrouter] [--grep ]`` * ``suites list`` @@ -18,7 +18,7 @@ publish its own flags. Design choices worth flagging: -* ``setup`` rejects ``chat_model_id == 0`` (Auto / LiteLLM router) so +* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so per-question accuracy is reproducible. * ``setup`` validates that the picked LLM config has ``provider == "OPENROUTER"`` and ``model_name == --provider-model`` @@ -59,6 +59,7 @@ if sys.platform == "win32": from . 
import registry from .auth import CredentialError, acquire_token, client_with_auth from .clients import SearchSpaceClient +from .clients.search_space import LlmPreferences from .config import ( DEFAULT_SCENARIO, SCENARIOS, @@ -110,30 +111,23 @@ class LlmConfigEntry: def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry: return cls( id=int(payload["id"]), - name=str(payload.get("display_name") or payload.get("name") or ""), + name=str(payload.get("name", "")), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_id") or payload.get("model_name") or ""), + model_name=str(payload.get("model_name", "")), raw=payload, ) async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]: response = await http.get( - f"{base}/api/v1/model-connections/global", + f"{base}/api/v1/global-new-llm-configs", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): - raise RuntimeError(f"Unexpected /model-connections/global payload: {payload!r}") - entries: list[LlmConfigEntry] = [] - for connection in payload: - provider = connection.get("provider", "") - for model in connection.get("models") or []: - if not model.get("enabled", True) or not model.get("supports_chat"): - continue - entries.append(LlmConfigEntry.from_payload({**model, "provider": provider})) - return entries + raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}") + return [LlmConfigEntry.from_payload(item) for item in payload] def _resolve_openrouter_id( @@ -149,8 +143,8 @@ def _resolve_openrouter_id( * If ``explicit_id`` is given: return it directly. The caller is then expected to GET-validate that the row's ``provider == "OPENROUTER"`` and ``model_name`` matches the slug. - That branch supports positive BYOK model rows whose slugs may overlap - with global OpenRouter virtuals. 
+ That branch supports positive BYOK ``NewLLMConfig`` rows whose + slugs may overlap with global OpenRouter virtuals. * Otherwise: filter to ``provider == "OPENROUTER"`` and ``model_name == provider_model``. Expect exactly one match — raise with a friendly message otherwise. @@ -179,7 +173,7 @@ def _resolve_openrouter_id( listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches) raise RuntimeError( f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n" - "Pass --chat-model-id to disambiguate." + "Pass --agent-llm-id to disambiguate." ) return matches[0].id @@ -192,7 +186,7 @@ def _resolve_openrouter_id( async def _cmd_setup(args: argparse.Namespace) -> int: suite = args.suite provider_model: str = args.provider_model - explicit_id: int | None = args.chat_model_id + explicit_id: int | None = args.agent_llm_id scenario: str = args.scenario vision_llm_slug: str | None = args.vision_llm native_arm_model: str | None = args.native_arm_model @@ -200,7 +194,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: if explicit_id == 0: console.print( - "[red]chat_model_id == 0 (Auto / LiteLLM router) is not allowed — " + "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — " "results would not be reproducible.[/red]" ) return 2 @@ -248,7 +242,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: candidates = await _list_global_llm_configs(http, config.surfsense_api_base) try: - chat_model_id = _resolve_openrouter_id( + agent_llm_id = _resolve_openrouter_id( candidates, provider_model, explicit_id=explicit_id ) except RuntimeError as exc: @@ -294,7 +288,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: vision_provider_model: str | None = None if not skip_vision_setup and (vision_required or vision_llm_slug is not None): try: - vision_candidates = await ss_client.list_global_vision_models() + vision_candidates = await ss_client.list_global_vision_llm_configs() resolved = resolve_vision_llm( vision_candidates, 
explicit_slug=vision_llm_slug ) @@ -308,34 +302,37 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"(id={vision_config_id}, selected_via={resolved.selected_via})." ) - role_kwargs: dict[str, Any] = {"chat_model_id": chat_model_id} + pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id} if vision_config_id is not None: - role_kwargs["vision_model_id"] = vision_config_id + pref_kwargs["vision_llm_config_id"] = vision_config_id - await ss_client.set_model_roles(search_space_id, **role_kwargs) - roles = await ss_client.get_model_roles(search_space_id) - if roles.chat_model_id != chat_model_id: + await ss_client.set_llm_preferences(search_space_id, **pref_kwargs) + prefs = await ss_client.get_llm_preferences(search_space_id) + if not _validate_pin(prefs, provider_model): + agent = prefs.agent_llm or {} console.print( f"[red]LLM pin validation FAILED.[/red] After PUT, " - f"chat_model_id={roles.chat_model_id!r}; expected {chat_model_id!r}." + f"agent_llm.provider={agent.get('provider')!r}, " + f"model_name={agent.get('model_name')!r}; expected " + f"provider=OPENROUTER, model_name={provider_model!r}." ) return 2 - if vision_config_id is not None and roles.vision_model_id != vision_config_id: + if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id: console.print( f"[red]Vision LLM pin validation FAILED.[/red] After PUT, " - f"vision_model_id={roles.vision_model_id!r}; " + f"vision_llm_config_id={prefs.vision_llm_config_id!r}; " f"expected {vision_config_id!r}." 
) return 2 suite_state = SuiteState( search_space_id=search_space_id, - chat_model_id=chat_model_id, + agent_llm_id=agent_llm_id, provider_model=provider_model, created_at=utc_iso_timestamp(), ingestion_maps=existing.ingestion_maps if existing else {}, scenario=scenario, - vision_model_id=vision_config_id, + vision_llm_config_id=vision_config_id, vision_provider_model=vision_provider_model, native_arm_model=native_arm_model, ) @@ -345,7 +342,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"suite={suite!r}", f"scenario={scenario!r}", f"search_space_id={suite_state.search_space_id}", - f"chat_model_id={suite_state.chat_model_id}", + f"agent_llm_id={suite_state.agent_llm_id}", f"provider_model={suite_state.provider_model!r}", ] if suite_state.vision_provider_model: @@ -356,6 +353,14 @@ async def _cmd_setup(args: argparse.Namespace) -> int: return 0 +def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool: + agent = prefs.agent_llm or {} + return ( + str(agent.get("provider", "")).upper() == "OPENROUTER" + and str(agent.get("model_name", "")) == provider_model + ) + + async def _cmd_teardown(args: argparse.Namespace) -> int: suite = args.suite config = load_config() @@ -649,10 +654,10 @@ def _build_parser() -> argparse.ArgumentParser: ), ) p_setup.add_argument( - "--chat-model-id", + "--agent-llm-id", type=int, default=None, - help="Optional explicit model id override.", + help="Optional override for BYOK NewLLMConfig rows.", ) p_setup.add_argument( "--scenario", diff --git a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py index efd4a571d..e2d37694d 100644 --- a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py +++ b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py @@ -1,16 +1,17 @@ -"""Client for ``/api/v1/searchspaces`` and model-role endpoints. 
+"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``. Verified against: * ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create) * ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id) * ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete) -* ``surfsense_backend/app/routes/model_connections_routes.py`` (GET/PUT model roles) +* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences) * ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body) +* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs) Note the inconsistent pluralisation in the backend: ``/searchspaces`` -(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for model-role -sub-resources. Both are mirrored verbatim here. +(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the +``llm-preferences`` sub-resource. Both are mirrored verbatim here. """ from __future__ import annotations @@ -45,8 +46,13 @@ class SearchSpaceRow: @dataclass -class VisionModelEntry: - """Subset of one GLOBAL model-connection model with image input support.""" +class VisionLlmConfigEntry: + """Subset of one ``GET /global-vision-llm-configs`` row. + + The backend returns negative ids for global / OpenRouter-derived + vision configs and positive ids for per-user BYOK rows. Either is + accepted by ``set_llm_preferences(vision_llm_config_id=...)``. 
+ """ id: int name: str @@ -56,38 +62,45 @@ class VisionModelEntry: raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> VisionModelEntry: + def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry: return cls( id=int(payload.get("id", 0)), - name=str(payload.get("display_name") or payload.get("model_id") or ""), + name=str(payload.get("name", "")), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_id", "")), - is_auto_mode=False, + model_name=str(payload.get("model_name", "")), + is_auto_mode=bool(payload.get("is_auto_mode", False)), raw=payload, ) @dataclass -class ModelRoles: - """Model role ids for a search space.""" +class LlmPreferences: + """Resolved LLM preferences with the embedded full config row. - chat_model_id: int | None - image_gen_model_id: int | None - vision_model_id: int | None + Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle + command can introspect ``provider`` / ``model_name`` to validate the + OpenRouter pin. 
+ """ + + agent_llm_id: int | None + image_generation_config_id: int | None + vision_llm_config_id: int | None + agent_llm: dict[str, Any] | None raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> ModelRoles: + def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences: return cls( - chat_model_id=payload.get("chat_model_id"), - image_gen_model_id=payload.get("image_gen_model_id"), - vision_model_id=payload.get("vision_model_id"), + agent_llm_id=payload.get("agent_llm_id"), + image_generation_config_id=payload.get("image_generation_config_id"), + vision_llm_config_id=payload.get("vision_llm_config_id"), + agent_llm=payload.get("agent_llm"), raw=payload, ) class SearchSpaceClient: - """Thin wrapper around the SearchSpace + model role endpoints.""" + """Thin wrapper around the SearchSpace + LLM preferences endpoints.""" def __init__(self, http: httpx.AsyncClient, base_url: str) -> None: self._http = http @@ -126,67 +139,64 @@ class SearchSpaceClient: return response.raise_for_status() - async def get_model_roles(self, search_space_id: int) -> ModelRoles: + async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences: response = await self._http.get( - f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", + f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", headers={"Accept": "application/json"}, ) response.raise_for_status() - return ModelRoles.from_payload(response.json()) + return LlmPreferences.from_payload(response.json()) - async def set_model_roles( + async def set_llm_preferences( self, search_space_id: int, *, - chat_model_id: int | None = None, - image_gen_model_id: int | None = None, - vision_model_id: int | None = None, - ) -> ModelRoles: - """PUT a partial update to ``/search-spaces/{id}/model-roles``. 
+ agent_llm_id: int | None = None, + image_generation_config_id: int | None = None, + vision_llm_config_id: int | None = None, + ) -> LlmPreferences: + """PUT a partial update to ``/search-spaces/{id}/llm-preferences``. Backend uses ``model_dump(exclude_unset=True)`` so omitted fields are left unchanged. """ body: dict[str, Any] = {} - if chat_model_id is not None: - body["chat_model_id"] = chat_model_id - if image_gen_model_id is not None: - body["image_gen_model_id"] = image_gen_model_id - if vision_model_id is not None: - body["vision_model_id"] = vision_model_id + if agent_llm_id is not None: + body["agent_llm_id"] = agent_llm_id + if image_generation_config_id is not None: + body["image_generation_config_id"] = image_generation_config_id + if vision_llm_config_id is not None: + body["vision_llm_config_id"] = vision_llm_config_id response = await self._http.put( - f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", + f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", json=body, headers={"Accept": "application/json"}, ) response.raise_for_status() - return ModelRoles.from_payload(response.json()) + return LlmPreferences.from_payload(response.json()) - async def list_global_vision_models(self) -> list[VisionModelEntry]: - """List registered GLOBAL models that can accept image input. + async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]: + """List the registered global vision LLM configs. - Used by ``setup`` to resolve ``--vision-llm `` or auto-pick a - reproducible ingest-time vision model. + Used by ``setup`` to (a) resolve an explicit ``--vision-llm `` + to a config id and (b) auto-pick the strongest registered vision + config when the operator doesn't pass one. The ``Auto (Fastest)`` + entry (``id=0``) is filtered out — accuracy must be reproducible. 
""" response = await self._http.get( - f"{self._base}/api/v1/model-connections/global", + f"{self._base}/api/v1/global-vision-llm-configs", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): raise RuntimeError( - f"Unexpected /model-connections/global payload: {payload!r}" + f"Unexpected /global-vision-llm-configs payload: {payload!r}" ) - entries: list[VisionModelEntry] = [] - for connection in payload: - provider = str(connection.get("provider", "")) - for model in connection.get("models") or []: - if not model.get("enabled", True) or not model.get("supports_image_input"): - continue - entries.append( - VisionModelEntry.from_payload({**model, "provider": provider}) - ) - return entries + return [ + VisionLlmConfigEntry.from_payload(item) + for item in payload + if not bool(item.get("is_auto_mode", False)) + ] diff --git a/surfsense_evals/src/surfsense_evals/core/config.py b/surfsense_evals/src/surfsense_evals/core/config.py index 9a5a71e89..164955914 100644 --- a/surfsense_evals/src/surfsense_evals/core/config.py +++ b/surfsense_evals/src/surfsense_evals/core/config.py @@ -147,35 +147,35 @@ class SuiteState: """Per-suite persisted state. ``provider_model`` is the slug pinned to the SearchSpace's - ``chat_model_id`` — what answers SurfSense queries (and what the native + ``agent_llm`` — what answers SurfSense queries (and what the native arm uses too, unless ``native_arm_model`` is set for cost-arbitrage). - ``vision_provider_model`` is the slug of the OpenRouter vision model - attached to the SearchSpace's ``vision_model_id`` — what + ``vision_provider_model`` is the slug of the OpenRouter vision LLM + config attached to the SearchSpace's ``vision_llm_config_id`` — what SurfSense uses to extract image content at ingest time when ``use_vision_llm=True``. ``None`` means no vision config was attached at setup (legacy or text-only suite). 
""" search_space_id: int - chat_model_id: int + agent_llm_id: int provider_model: str created_at: str ingestion_maps: dict[str, str] = field(default_factory=dict) scenario: str = DEFAULT_SCENARIO - vision_model_id: int | None = None + vision_llm_config_id: int | None = None vision_provider_model: str | None = None native_arm_model: str | None = None def to_dict(self) -> dict[str, Any]: return { "search_space_id": self.search_space_id, - "chat_model_id": self.chat_model_id, + "agent_llm_id": self.agent_llm_id, "provider_model": self.provider_model, "created_at": self.created_at, "ingestion_maps": dict(self.ingestion_maps), "scenario": self.scenario, - "vision_model_id": self.vision_model_id, + "vision_llm_config_id": self.vision_llm_config_id, "vision_provider_model": self.vision_provider_model, "native_arm_model": self.native_arm_model, } @@ -187,16 +187,15 @@ class SuiteState: scenario = str(payload.get("scenario") or DEFAULT_SCENARIO) if scenario not in SCENARIOS: scenario = DEFAULT_SCENARIO - raw_chat_id = payload.get("chat_model_id") - raw_vision_id = payload.get("vision_model_id") + raw_vision_id = payload.get("vision_llm_config_id") return cls( search_space_id=int(payload["search_space_id"]), - chat_model_id=int(raw_chat_id), + agent_llm_id=int(payload["agent_llm_id"]), provider_model=str(payload["provider_model"]), created_at=str(payload.get("created_at") or ""), ingestion_maps=dict(payload.get("ingestion_maps") or {}), scenario=scenario, - vision_model_id=int(raw_vision_id) if raw_vision_id is not None else None, + vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None, vision_provider_model=( str(payload["vision_provider_model"]) if payload.get("vision_provider_model") diff --git a/surfsense_evals/src/surfsense_evals/core/registry.py b/surfsense_evals/src/surfsense_evals/core/registry.py index 65f64c39a..cc8b725e0 100644 --- a/surfsense_evals/src/surfsense_evals/core/registry.py +++ 
b/surfsense_evals/src/surfsense_evals/core/registry.py @@ -53,8 +53,8 @@ class RunContext: return self.suite_state.search_space_id @property - def chat_model_id(self) -> int: - return self.suite_state.chat_model_id + def agent_llm_id(self) -> int: + return self.suite_state.agent_llm_id @property def provider_model(self) -> str: diff --git a/surfsense_evals/src/surfsense_evals/core/vision_llm.py b/surfsense_evals/src/surfsense_evals/core/vision_llm.py index 5d5e2c6d1..ae96f1285 100644 --- a/surfsense_evals/src/surfsense_evals/core/vision_llm.py +++ b/surfsense_evals/src/surfsense_evals/core/vision_llm.py @@ -3,8 +3,8 @@ Two responsibilities: 1. Resolve an explicit ``--vision-llm `` to a global OpenRouter - vision-capable model id that ``set_model_roles(vision_model_id=...)`` can - accept. + vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)`` + can accept. 2. Auto-pick the strongest registered vision config when the operator doesn't pass ``--vision-llm`` but the scenario / benchmark needs one. 
diff --git a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py index ac0651996..e1a830138 100644 --- a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py @@ -371,7 +371,7 @@ class MedXpertQAMMBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py index b7685766e..95a1e15eb 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py @@ -391,7 +391,7 @@ class MMLongBenchDocBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py index 2c4a0ffe4..e71dffa65 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py @@ -554,7 +554,7 @@ class ParserCompareBenchmark: "scenario": ctx.scenario, "provider_model": ctx.provider_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "preprocess_tariff": { "basic_per_1k_pages": 1.0, 
"premium_per_1k_pages": 10.0, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py index 654c261a2..8b759e0d8 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py @@ -467,7 +467,7 @@ class CragBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, "per_page_char_cap": per_page_char_cap, "max_output_tokens": max_output_tokens, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py index 450c7ddd6..9c0e16b00 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py @@ -372,7 +372,7 @@ class FramesBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "chat_model_id": ctx.chat_model_id, + "agent_llm_id": ctx.agent_llm_id, "ingest_settings": ingest_settings, "bare_arm_label": "bare_llm", }, diff --git a/surfsense_evals/tests/core/test_clients.py b/surfsense_evals/tests/core/test_clients.py index aa98f0ad4..611408703 100644 --- a/surfsense_evals/tests/core/test_clients.py +++ b/surfsense_evals/tests/core/test_clients.py @@ -63,22 +63,29 @@ async def test_delete_search_space_idempotent_on_404(respx_mock, http): @pytest.mark.asyncio @respx.mock(base_url=_BASE) -async def test_set_model_roles_partial_update(respx_mock, http): - route = respx_mock.put("/api/v1/search-spaces/42/model-roles").mock( +async def test_set_llm_preferences_partial_update(respx_mock, http): + route = 
respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock( return_value=httpx.Response( 200, json={ - "chat_model_id": -10042, - "image_gen_model_id": None, - "vision_model_id": None, + "agent_llm_id": -10042, + "agent_llm_id": None, + "image_generation_config_id": None, + "vision_llm_config_id": None, + "agent_llm": { + "id": -10042, + "provider": "OPENROUTER", + "model_name": "anthropic/claude-sonnet-4.5", + }, }, ) ) client = SearchSpaceClient(http, _BASE) - roles = await client.set_model_roles(42, chat_model_id=-10042) - assert roles.chat_model_id == -10042 + prefs = await client.set_llm_preferences(42, agent_llm_id=-10042) + assert prefs.agent_llm_id == -10042 + assert prefs.agent_llm["provider"] == "OPENROUTER" sent_body = json.loads(route.calls[-1].request.content) - assert sent_body == {"chat_model_id": -10042} + assert sent_body == {"agent_llm_id": -10042} # --------------------------------------------------------------------------- diff --git a/surfsense_evals/tests/core/test_config.py b/surfsense_evals/tests/core/test_config.py index 6f9671c86..f7b8f7249 100644 --- a/surfsense_evals/tests/core/test_config.py +++ b/surfsense_evals/tests/core/test_config.py @@ -41,14 +41,14 @@ def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001 assert get_suite_state(config, "medical") is None state = SuiteState( search_space_id=1, - chat_model_id=-10042, + agent_llm_id=-10042, provider_model="anthropic/claude-sonnet-4.5", created_at="2026-05-11T20-30-00Z", ) set_suite_state(config, "medical", state) legal = SuiteState( search_space_id=2, - chat_model_id=-1, + agent_llm_id=-1, provider_model="openai/gpt-5", created_at="2026-05-11T21-00-00Z", ) @@ -84,19 +84,25 @@ def test_paths_are_per_suite(tmp_env): # noqa: ARG001 # --------------------------------------------------------------------------- -def test_minimal_state_defaults_to_head_to_head(): - """Missing scenario / vision / native fields default safely.""" +def 
test_legacy_state_back_compat_defaults_to_head_to_head(): + """state.json files written before scenarios shipped must still load. - payload = { + Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all + default to ``head-to-head`` / ``None`` so old setups keep working + after upgrade — the runner's behaviour exactly mirrors the legacy + one (both arms answer with ``provider_model``). + """ + + legacy = { "search_space_id": 7, - "chat_model_id": -123, + "agent_llm_id": -123, "provider_model": "anthropic/claude-sonnet-4.5", "created_at": "2026-05-11T20-30-00Z", "ingestion_maps": {}, } - state = SuiteState.from_dict(payload) + state = SuiteState.from_dict(legacy) assert state.scenario == DEFAULT_SCENARIO == "head-to-head" - assert state.vision_model_id is None + assert state.vision_llm_config_id is None assert state.vision_provider_model is None assert state.native_arm_model is None # The native arm should still answer with the same slug as SurfSense. @@ -112,7 +118,7 @@ def test_unknown_scenario_falls_back_to_default(): payload = { "search_space_id": 1, - "chat_model_id": -1, + "agent_llm_id": -1, "provider_model": "openai/gpt-5", "scenario": "unknown-scenario-name", } @@ -124,11 +130,11 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 config = load_config() state = SuiteState( search_space_id=42, - chat_model_id=-1, + agent_llm_id=-1, provider_model="openai/gpt-5.4-mini", created_at="2026-05-11T20-30-00Z", scenario="cost-arbitrage", - vision_model_id=-101, + vision_llm_config_id=-101, vision_provider_model="anthropic/claude-sonnet-4.5", native_arm_model="anthropic/claude-sonnet-4.5", ) @@ -136,7 +142,7 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 fetched = get_suite_state(config, "medical") assert fetched.scenario == "cost-arbitrage" - assert fetched.vision_model_id == -101 + assert fetched.vision_llm_config_id == -101 assert fetched.vision_provider_model == 
"anthropic/claude-sonnet-4.5" assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5" # Cost arbitrage's whole point: native arm slug != surfsense slug. diff --git a/surfsense_evals/tests/test_integration_smoke.py b/surfsense_evals/tests/test_integration_smoke.py index 1c89ae5ab..493c04c25 100644 --- a/surfsense_evals/tests/test_integration_smoke.py +++ b/surfsense_evals/tests/test_integration_smoke.py @@ -27,7 +27,7 @@ async def test_smoke_against_localhost(): pytest.skip("No credentials in environment; skipping integration smoke") bundle = await acquire_token(config) async with client_with_auth(config, bundle) as client: - response = await client.get(f"{config.surfsense_api_base}/api/v1/model-connections/global") + response = await client.get(f"{config.surfsense_api_base}/api/v1/global-new-llm-configs") try: response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/surfsense_web/.env.example b/surfsense_web/.env.example index 11646c948..5fb9d07d1 100644 --- a/surfsense_web/.env.example +++ b/surfsense_web/.env.example @@ -1,74 +1,30 @@ -# ───────────────────────────────────────────────────────────────────────────── -# Backend connectivity -# ───────────────────────────────────────────────────────────────────────────── +NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 -# Optional packaged-client override. Leave unset in Docker so browser requests -# use same-origin relative URLs behind Caddy. Set it for packaged clients -# (e.g. Electron) or local dev that talks to a separate backend origin. -# NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000 +# Server-only. Internal backend URL used by Next.js server code. +FASTAPI_BACKEND_INTERNAL_URL=https://your-internal-backend.example.com -# Server-only. Internal backend URL used by Next.js server code (RSC / route -# handlers). Cannot be a relative URL. 
-SURFSENSE_BACKEND_INTERNAL_URL=http://backend:8000 +NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL or GOOGLE +NEXT_PUBLIC_ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING +NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848 -# ───────────────────────────────────────────────────────────────────────────── -# Runtime configuration (read at runtime by the server, no rebuild needed) -# ───────────────────────────────────────────────────────────────────────────── - -# Authentication method: LOCAL (email/password) or GOOGLE (OAuth). -AUTH_TYPE=LOCAL -# Document parsing backend: DOCLING, LLAMACLOUD, etc. -ETL_SERVICE=DOCLING -# Deployment mode: self-hosted or cloud. -DEPLOYMENT_MODE=self-hosted - -# ───────────────────────────────────────────────────────────────────────────── -# Build-time fallbacks for packaged clients (e.g. Electron) without a runtime -# config provider. Optional; Docker reads the plain runtime vars above first. -# ───────────────────────────────────────────────────────────────────────────── -# NEXT_PUBLIC_AUTH_TYPE=GOOGLE -# NEXT_PUBLIC_ETL_SERVICE=DOCLING -# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted -# Overrides the app version shown in the UI (defaults to package.json version). 
-# NEXT_PUBLIC_APP_VERSION= - -# ───────────────────────────────────────────────────────────────────────────── -# Database (Contact Form, optional) -# ───────────────────────────────────────────────────────────────────────────── +# Contact Form Vars (optional) DATABASE_URL=postgresql://postgres:[YOUR-PASSWORD]@db.sdsf.supabase.co:5432/postgres -# ───────────────────────────────────────────────────────────────────────────── -# PostHog analytics (optional, leave key empty to disable) -# ───────────────────────────────────────────────────────────────────────────── +# Deployment mode (optional) +NEXT_PUBLIC_DEPLOYMENT_MODE="self-hosted" or "cloud" + +# PostHog analytics (optional, leave empty to disable) NEXT_PUBLIC_POSTHOG_KEY= -NEXT_PUBLIC_POSTHOG_HOST=https://us.i.posthog.com -# ───────────────────────────────────────────────────────────────────────────── -# Zero cache (real-time sync). Leave unset in Docker to use the same-origin -# "/zero" endpoint behind Caddy. Set it for local dev or packaged clients. -# ───────────────────────────────────────────────────────────────────────────── -# NEXT_PUBLIC_ZERO_CACHE_URL=http://localhost:4848 - -# ───────────────────────────────────────────────────────────────────────────── # Cloudflare Turnstile CAPTCHA for anonymous chat abuse prevention # Get your site key from https://dash.cloudflare.com/ -> Turnstile -# ───────────────────────────────────────────────────────────────────────────── NEXT_PUBLIC_TURNSTILE_SITE_KEY= -# ───────────────────────────────────────────────────────────────────────────── # Google AdSense (optional, only enables ads on the /free hub page). # Publisher ID from your AdSense dashboard, e.g. ca-pub-XXXXXXXXXXXXXXXX. # Leave empty to disable ad rendering entirely. -# ───────────────────────────────────────────────────────────────────────────── NEXT_PUBLIC_GOOGLE_ADSENSE_CLIENT_ID= # Ad unit slot IDs from AdSense dashboard -> Ads -> By ad unit. 
# Leave empty to hide individual slots while keeping the script loaded. NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_IN_CONTENT= -NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_BEFORE_FAQ= - -# ───────────────────────────────────────────────────────────────────────────── -# Global announcement banner (e.g. planned downtime / maintenance notice). -# Set ENABLED to "true" to show the banner, and put the notice text in MESSAGE. -# ───────────────────────────────────────────────────────────────────────────── -NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED=false -NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE= +NEXT_PUBLIC_GOOGLE_ADSENSE_SLOT_FREE_HUB_BEFORE_FAQ= \ No newline at end of file diff --git a/surfsense_web/Dockerfile b/surfsense_web/Dockerfile index 48cc28594..d851cf045 100644 --- a/surfsense_web/Dockerfile +++ b/surfsense_web/Dockerfile @@ -35,6 +35,21 @@ RUN apk add --no-cache git # Enable pnpm RUN corepack enable pnpm +# Build with placeholder values for NEXT_PUBLIC_* variables. +# These are replaced at container startup by docker-entrypoint.js +# with real values from the container's environment variables. +ARG NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__ +ARG NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__ +ARG NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__ +ARG NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__ +ARG NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__ + +ENV NEXT_PUBLIC_FASTAPI_BACKEND_URL=$NEXT_PUBLIC_FASTAPI_BACKEND_URL +ENV NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=$NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE +ENV NEXT_PUBLIC_ETL_SERVICE=$NEXT_PUBLIC_ETL_SERVICE +ENV NEXT_PUBLIC_ZERO_CACHE_URL=$NEXT_PUBLIC_ZERO_CACHE_URL +ENV NEXT_PUBLIC_DEPLOYMENT_MODE=$NEXT_PUBLIC_DEPLOYMENT_MODE + COPY --from=deps /app/node_modules ./node_modules COPY . . 
@@ -63,6 +78,10 @@ COPY --from=builder /app/public ./public COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone/app/ ./ COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static +# Entrypoint scripts for runtime env var substitution +COPY --chown=nextjs:nodejs docker-entrypoint.js ./docker-entrypoint.js +COPY --chown=nextjs:nodejs --chmod=755 docker-entrypoint.sh ./docker-entrypoint.sh + USER nextjs EXPOSE 3000 @@ -72,4 +91,4 @@ ENV PORT=3000 # server.js is created by next build from the standalone output # https://nextjs.org/docs/pages/api-reference/config/next-config-js/output ENV HOSTNAME="0.0.0.0" -CMD ["node", "server.js"] \ No newline at end of file +ENTRYPOINT ["/bin/sh", "./docker-entrypoint.sh"] \ No newline at end of file diff --git a/surfsense_web/app/(home)/free/[model_slug]/page.tsx b/surfsense_web/app/(home)/free/[model_slug]/page.tsx index 71fc925e4..e72c3d6e3 100644 --- a/surfsense_web/app/(home)/free/[model_slug]/page.tsx +++ b/surfsense_web/app/(home)/free/[model_slug]/page.tsx @@ -7,7 +7,7 @@ import { FAQJsonLd, JsonLd } from "@/components/seo/json-ld"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; import type { AnonModel } from "@/contracts/types/anonymous-chat.types"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; interface PageProps { params: Promise<{ model_slug: string }>; @@ -16,7 +16,7 @@ interface PageProps { async function getModel(slug: string): Promise { try { const res = await fetch( - `${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models/${encodeURIComponent(slug)}`, + `${BACKEND_URL}/api/v1/public/anon-chat/models/${encodeURIComponent(slug)}`, { next: { revalidate: 300 } } ); if (!res.ok) return null; @@ -28,7 +28,7 @@ async function getModel(slug: string): Promise { async function getAllModels(): Promise { try { - const res = await 
fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 300 }, }); if (!res.ok) return []; @@ -136,7 +136,7 @@ export async function generateMetadata({ params }: PageProps): Promise export async function generateStaticParams() { const models = await getAllModels(); - return models.flatMap((m) => (m.seo_slug ? [{ model_slug: m.seo_slug }] : [])); + return models.filter((m) => m.seo_slug).map((m) => ({ model_slug: m.seo_slug! })); } export default async function FreeModelPage({ params }: PageProps) { diff --git a/surfsense_web/app/(home)/free/page.tsx b/surfsense_web/app/(home)/free/page.tsx index b754502f6..0092ca2d5 100644 --- a/surfsense_web/app/(home)/free/page.tsx +++ b/surfsense_web/app/(home)/free/page.tsx @@ -16,7 +16,7 @@ import { TableRow, } from "@/components/ui/table"; import type { AnonModel } from "@/contracts/types/anonymous-chat.types"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; export const metadata: Metadata = { title: "Free AI Chat, No Login Required | SurfSense", @@ -94,7 +94,7 @@ export const metadata: Metadata = { async function getModels(): Promise { try { - const res = await fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 300 }, }); if (!res.ok) return []; @@ -246,6 +246,11 @@ export default async function FreeHubPage() { className="group flex flex-col gap-0.5" > {model.name} + {model.description && ( + + {model.description} + + )} diff --git a/surfsense_web/app/(home)/layout.tsx b/surfsense_web/app/(home)/layout.tsx index c749b10f3..57dd9919e 100644 --- a/surfsense_web/app/(home)/layout.tsx +++ b/surfsense_web/app/(home)/layout.tsx @@ -2,7 +2,6 @@ import { usePathname } from "next/navigation"; import { FooterNew } from "@/components/homepage/footer-new"; -import { 
GlobalAnnouncement } from "@/components/homepage/global-announcement"; import { Navbar } from "@/components/homepage/navbar"; export default function HomePageLayout({ children }: { children: React.ReactNode }) { @@ -16,7 +15,6 @@ export default function HomePageLayout({ children }: { children: React.ReactNode return (
- {children} {!isAuthPage && } diff --git a/surfsense_web/app/(home)/login/GoogleLoginButton.tsx b/surfsense_web/app/(home)/login/GoogleLoginButton.tsx index a9e1b553e..1c91f8115 100644 --- a/surfsense_web/app/(home)/login/GoogleLoginButton.tsx +++ b/surfsense_web/app/(home)/login/GoogleLoginButton.tsx @@ -3,7 +3,7 @@ import { useTranslations } from "next-intl"; import { useState } from "react"; import { Logo } from "@/components/Logo"; import { Button } from "@/components/ui/button"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { AmbientBackground } from "./AmbientBackground"; @@ -51,7 +51,7 @@ export function GoogleLoginButton() { // cross-origin fetch requests may not be sent on subsequent redirects. // The authorize-redirect endpoint does a server-side redirect to Google // and sets the CSRF cookie properly for same-site context. - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; return (
diff --git a/surfsense_web/app/(home)/login/LocalLoginForm.tsx b/surfsense_web/app/(home)/login/LocalLoginForm.tsx index 108151512..9692d35e1 100644 --- a/surfsense_web/app/(home)/login/LocalLoginForm.tsx +++ b/surfsense_web/app/(home)/login/LocalLoginForm.tsx @@ -7,10 +7,10 @@ import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useState } from "react"; import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { getAuthErrorDetails, isNetworkError } from "@/lib/auth-errors"; +import { AUTH_TYPE } from "@/lib/env-config"; import { ValidationError } from "@/lib/error"; import { trackLoginAttempt, trackLoginFailure, trackLoginSuccess } from "@/lib/posthog/events"; @@ -26,7 +26,7 @@ export function LocalLoginForm() { title: null, message: null, }); - const { authType } = useRuntimeConfig(); + const authType = AUTH_TYPE; const router = useRouter(); const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom); diff --git a/surfsense_web/app/(home)/login/layout.tsx b/surfsense_web/app/(home)/login/layout.tsx deleted file mode 100644 index e14aec239..000000000 --- a/surfsense_web/app/(home)/login/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function LoginLayout({ children }: { children: React.ReactNode }) { - return {children}; -} diff --git a/surfsense_web/app/(home)/login/page.tsx b/surfsense_web/app/(home)/login/page.tsx index 8f146f815..42a9182e9 100644 --- a/surfsense_web/app/(home)/login/page.tsx +++ b/surfsense_web/app/(home)/login/page.tsx @@ -6,10 +6,11 @@ import { useTranslations } from "next-intl"; import { Suspense, useEffect, useState } from "react"; import { toast } from "sonner"; import { Logo } from 
"@/components/Logo"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; +import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getAuthErrorDetails, shouldRetry } from "@/lib/auth-errors"; import { setRedirectPath } from "@/lib/auth-utils"; +import { AUTH_TYPE } from "@/lib/env-config"; import { AmbientBackground } from "./AmbientBackground"; import { GoogleLoginButton } from "./GoogleLoginButton"; import { LocalLoginForm } from "./LocalLoginForm"; @@ -18,7 +19,8 @@ function LoginContent() { const t = useTranslations("auth"); const tCommon = useTranslations("common"); const router = useRouter(); - const { authType } = useRuntimeConfig(); + const [authType, setAuthType] = useState(null); + const [isLoading, setIsLoading] = useState(true); const [urlError, setUrlError] = useState<{ title: string; message: string } | null>(null); const searchParams = useSearchParams(); @@ -94,7 +96,19 @@ function LoginContent() { duration: 4000, }); } - }, [router, searchParams, t, tCommon]); + + // Get the auth type from centralized config + setAuthType(AUTH_TYPE); + setIsLoading(false); + }, [searchParams, t, tCommon]); + + // Use global loading screen for auth type determination - spinner animation won't reset + useGlobalLoadingEffect(isLoading); + + // Show nothing while loading - the GlobalLoadingProvider handles the loading UI + if (isLoading) { + return null; + } if (authType === "GOOGLE") { return ; diff --git a/surfsense_web/app/(home)/register/layout.tsx b/surfsense_web/app/(home)/register/layout.tsx deleted file mode 100644 index 361df85c1..000000000 --- a/surfsense_web/app/(home)/register/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function RegisterLayout({ children }: { children: React.ReactNode }) { - return {children}; -} diff --git a/surfsense_web/app/(home)/register/page.tsx 
b/surfsense_web/app/(home)/register/page.tsx index 9421a0156..1fd1a4ecb 100644 --- a/surfsense_web/app/(home)/register/page.tsx +++ b/surfsense_web/app/(home)/register/page.tsx @@ -9,11 +9,11 @@ import { useEffect, useState } from "react"; import { type ExternalToast, toast } from "sonner"; import { registerMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; import { Logo } from "@/components/Logo"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { getAuthErrorDetails, isNetworkError, shouldRetry } from "@/lib/auth-errors"; import { getBearerToken } from "@/lib/auth-utils"; +import { AUTH_TYPE } from "@/lib/env-config"; import { AppError, ValidationError } from "@/lib/error"; import { trackRegistrationAttempt, @@ -25,7 +25,6 @@ import { AmbientBackground } from "../login/AmbientBackground"; export default function RegisterPage() { const t = useTranslations("auth"); const tCommon = useTranslations("common"); - const { authType } = useRuntimeConfig(); const [email, setEmail] = useState(""); const [password, setPassword] = useState(""); const [confirmPassword, setConfirmPassword] = useState(""); @@ -45,10 +44,10 @@ export default function RegisterPage() { router.replace("/dashboard"); return; } - if (authType !== "LOCAL") { + if (AUTH_TYPE !== "LOCAL") { router.push("/login"); } - }, [authType, router]); + }, [router]); const handleSubmit = (e: React.FormEvent) => { e.preventDefault(); diff --git a/surfsense_web/app/api/v1/[...path]/route.ts b/surfsense_web/app/api/v1/[...path]/route.ts index 66ea78af5..418bf1a33 100644 --- a/surfsense_web/app/api/v1/[...path]/route.ts +++ b/surfsense_web/app/api/v1/[...path]/route.ts @@ -14,11 +14,7 @@ const HOP_BY_HOP_HEADERS = new Set([ ]); function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the 
post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; + const base = process.env.FASTAPI_BACKEND_INTERNAL_URL || "http://localhost:8000"; return base.endsWith("/") ? base.slice(0, -1) : base; } diff --git a/surfsense_web/app/api/zero/query/route.ts b/surfsense_web/app/api/zero/query/route.ts index f08b012e7..35ef51fb5 100644 --- a/surfsense_web/app/api/zero/query/route.ts +++ b/surfsense_web/app/api/zero/query/route.ts @@ -1,7 +1,7 @@ import { mustGetQuery } from "@rocicorp/zero"; import { handleQueryRequest } from "@rocicorp/zero/server"; import { NextResponse } from "next/server"; -import { SERVER_BACKEND_URL } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import type { Context } from "@/types/zero"; import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; @@ -11,7 +11,11 @@ import { schema } from "@/zero/schema"; // (e.g. http://backend:8000). The browser-facing NEXT_PUBLIC_FASTAPI_BACKEND_URL // (e.g. http://localhost:8929) does NOT resolve from inside the frontend // container and would make every authenticated Zero query fail with a 503. 
-const backendURL = SERVER_BACKEND_URL.replace(/\/$/, ""); +const backendURL = ( + process.env.FASTAPI_BACKEND_INTERNAL_URL || + process.env.BACKEND_URL || + "http://localhost:8000" +).replace(/\/$/, ""); async function authenticateRequest( request: Request diff --git a/surfsense_web/app/auth/[...path]/route.ts b/surfsense_web/app/auth/[...path]/route.ts deleted file mode 100644 index 923f6eef3..000000000 --- a/surfsense_web/app/auth/[...path]/route.ts +++ /dev/null @@ -1,74 +0,0 @@ -import type { NextRequest } from "next/server"; - -export const dynamic = "force-dynamic"; - -const HOP_BY_HOP_HEADERS = new Set([ - "connection", - "keep-alive", - "proxy-authenticate", - "proxy-authorization", - "te", - "trailer", - "transfer-encoding", - "upgrade", -]); - -function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. - process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; - return base.endsWith("/") ? base.slice(0, -1) : base; -} - -function toUpstreamHeaders(headers: Headers) { - const nextHeaders = new Headers(headers); - nextHeaders.delete("host"); - nextHeaders.delete("content-length"); - return nextHeaders; -} - -function toClientHeaders(headers: Headers) { - const nextHeaders = new Headers(headers); - for (const header of HOP_BY_HOP_HEADERS) { - nextHeaders.delete(header); - } - return nextHeaders; -} - -async function proxy(request: NextRequest, context: { params: Promise<{ path?: string[] }> }) { - const params = await context.params; - const path = params.path?.join("/") || ""; - const upstreamUrl = new URL(`${getBackendBaseUrl()}/auth/${path}`); - upstreamUrl.search = request.nextUrl.search; - - const hasBody = request.method !== "GET" && request.method !== "HEAD"; - - const response = await fetch(upstreamUrl, { - method: request.method, - headers: toUpstreamHeaders(request.headers), - body: hasBody ? 
request.body : undefined, - // `duplex: "half"` is required by the Fetch spec when streaming a - // ReadableStream as the request body. Avoids buffering uploads in heap. - // @ts-expect-error - `duplex` is not yet in lib.dom RequestInit types. - duplex: hasBody ? "half" : undefined, - redirect: "manual", - }); - - return new Response(response.body, { - status: response.status, - statusText: response.statusText, - headers: toClientHeaders(response.headers), - }); -} - -export { - proxy as GET, - proxy as POST, - proxy as PUT, - proxy as PATCH, - proxy as DELETE, - proxy as OPTIONS, - proxy as HEAD, -}; diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index 1ee71c636..b2e7b2532 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { AlarmClock } from "lucide-react"; +import { Workflow } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return (
- +

No automations yet

diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 74c604173..8314a5179 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { AlarmClock, CalendarDays, Info } from "lucide-react"; +import { CalendarDays, Info, Workflow } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ - + Name diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx index a68e53a1c..59967080f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx @@ -130,7 +130,7 @@ export function AutomationBuilderForm({ // data into state, so there's no flicker/loop and the user's pick is sticky. 
const resolvedModels = useMemo( () => ({ - chatModelId: form.models.chatModelId || eligibleModels.llm.defaultId || 0, + agentLlmId: form.models.agentLlmId || eligibleModels.llm.defaultId || 0, imageConfigId: form.models.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: form.models.visionConfigId || eligibleModels.vision.defaultId || 0, }), diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx index 6dd42366b..2c4a0bf60 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx @@ -25,7 +25,7 @@ import { getProviderIcon } from "@/lib/provider-icons"; import { Field } from "./form-field"; export interface AutomationModelSelection { - chatModelId: number; + agentLlmId: number; imageConfigId: number; visionConfigId: number; } @@ -39,7 +39,7 @@ interface AutomationModelFieldsProps { } /** - * Three eligible-only model pickers (Chat / Image / Vision) for the + * Three eligible-only model pickers (Agent LLM / Image / Vision) for the * automation builder + chat approval card. Options come from * {@link useAutomationEligibleModels} (premium globals + BYOK only); selection * is validated + snapshotted onto `definition.models` at create time. @@ -51,18 +51,18 @@ export function AutomationModelFields({ errors, }: AutomationModelFieldsProps) { const { llm, image, vision, isLoading } = useAutomationEligibleModels(); - const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/models`; + const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/roles`; return (

onChange({ chatModelId: id })} + error={errors?.agentLlmId} + onChange={(id) => onChange({ agentLlmId: id })} /> { + return isLlmOnboardingComplete(preferences.agent_llm_id, globalConfigs.length > 0); + }, [preferences.agent_llm_id, globalConfigs.length]); const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom); const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); + const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); + const hasAttemptedAutoConfig = useRef(false); const isOnboardingPage = pathname?.includes("/onboard"); const isOwner = access?.is_owner ?? false; - const isSearchSpaceReady = activeSearchSpaceId === searchSpaceId; - - useEffect(() => { - if (isSearchSpaceReady) return; - setHasCheckedOnboarding(false); - }, [isSearchSpaceReady]); useEffect(() => { if (isOnboardingPage) { @@ -66,27 +66,13 @@ export function DashboardClientLayout({ } if ( - isSearchSpaceReady && !loading && !accessLoading && !globalConfigsLoading && - !globalConfigStatusLoading && - !modelConnectionsLoading && - !hasCheckedOnboarding + !hasCheckedOnboarding && + !isAutoConfiguring ) { - // Onboarding is only relevant when no operator-provided - // global_llm_config.yaml exists. When it does, search spaces inherit - // the global config and should never be forced into onboarding. 
- if (globalConfigStatus?.exists) { - setHasCheckedOnboarding(true); - return; - } - - const onboardingComplete = isLlmOnboardingComplete( - modelRoles.chat_model_id, - globalConnections, - modelConnections - ); + const onboardingComplete = isOnboardingComplete(); if (onboardingComplete) { setHasCheckedOnboarding(true); @@ -98,25 +84,56 @@ export function DashboardClientLayout({ return; } + if (globalConfigs.length > 0 && !hasAttemptedAutoConfig.current) { + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + + const autoConfigureWithGlobal = async () => { + try { + const firstGlobalConfig = globalConfigs[0]; + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { + agent_llm_id: firstGlobalConfig.id, + }, + }); + + await refetchPreferences(); + + toast.success("AI configured automatically!", { + description: `Using ${firstGlobalConfig.name}. Customize in Settings.`, + }); + + setHasCheckedOnboarding(true); + } catch (error) { + console.error("Auto-configuration failed:", error); + router.push(`/dashboard/${searchSpaceId}/onboard`); + } finally { + setIsAutoConfiguring(false); + } + }; + + autoConfigureWithGlobal(); + return; + } + router.push(`/dashboard/${searchSpaceId}/onboard`); setHasCheckedOnboarding(true); } }, [ - isSearchSpaceReady, loading, accessLoading, globalConfigsLoading, - globalConfigStatusLoading, - globalConfigStatus, - modelConnectionsLoading, - modelRoles.chat_model_id, - globalConnections, - modelConnections, + isOnboardingComplete, isOnboardingPage, isOwner, + isAutoConfiguring, + globalConfigs, router, searchSpaceId, hasCheckedOnboarding, + updatePreferences, + refetchPreferences, ]); const electronAPI = useElectronAPI(); @@ -168,14 +185,10 @@ export function DashboardClientLayout({ // Determine if we should show loading const shouldShowLoading = - !hasCheckedOnboarding && - (!isSearchSpaceReady || - loading || - accessLoading || - globalConfigsLoading || - globalConfigStatusLoading || - 
modelConnectionsLoading) && - !isOnboardingPage; + (!hasCheckedOnboarding && + (loading || accessLoading || globalConfigsLoading) && + !isOnboardingPage) || + isAutoConfiguring; // Use global loading screen - spinner animation won't reset useGlobalLoadingEffect(shouldShowLoading); diff --git a/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts b/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts index 96f7dc349..304f33a33 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts +++ b/surfsense_web/app/dashboard/[search_space_id]/connectors/callback/route.ts @@ -16,12 +16,9 @@ export async function GET( }; const result = JSON.stringify(payload); - const response = new NextResponse(null, { - status: 302, - headers: { - Location: `/dashboard/${search_space_id}/new-chat`, - }, - }); + const redirectUrl = new URL(`/dashboard/${search_space_id}/new-chat`, request.url); + + const response = NextResponse.redirect(redirectUrl, { status: 302 }); response.cookies.set(OAUTH_RESULT_COOKIE, result, { path: "/", maxAge: 60, diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 3594e15eb..f048376cc 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -106,7 +106,7 @@ import { extractUserTurnForNewChatApi, type NewChatUserImagePayload, } from "@/lib/chat/user-turn-api-parts"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { NotFoundError } from "@/lib/error"; import { trackChatBlocked, @@ -613,18 +613,6 @@ export default function NewChatPage() { return; } - if (normalized.channel === "inline") { - if (normalized.assistantMessage) { - await persistAssistantErrorMessage({ - threadId, - 
assistantMsgId, - text: normalized.assistantMessage, - }); - } - toast.error(normalized.userMessage); - return; - } - toast.error(normalized.userMessage); }, [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] @@ -919,9 +907,10 @@ export default function NewChatPage() { if (threadId) { const token = getBearerToken(); if (token) { + const backendUrl = BACKEND_URL; try { const response = await fetch( - buildBackendUrl(`/api/v1/threads/${threadId}/cancel-active-turn`), + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, { method: "POST", headers: { @@ -1109,6 +1098,7 @@ export default function NewChatPage() { let streamBatcher: FrameBatchedUpdater | null = null; try { + const backendUrl = BACKEND_URL; const selection = await getAgentFilesystemSelection(searchSpaceId, { localFilesystemEnabled, }); @@ -1145,7 +1135,7 @@ export default function NewChatPage() { } const response = await fetchWithTurnCancellingRetry(() => - fetch(buildBackendUrl("/api/v1/new_chat"), { + fetch(`${backendUrl}/api/v1/new_chat`, { method: "POST", headers: { "Content-Type": "application/json", @@ -1640,11 +1630,12 @@ export default function NewChatPage() { } try { + const backendUrl = BACKEND_URL; const selection = await getAgentFilesystemSelection(searchSpaceId, { localFilesystemEnabled, }); const response = await fetchWithTurnCancellingRetry(() => - fetch(buildBackendUrl(`/api/v1/threads/${resumeThreadId}/resume`), { + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { method: "POST", headers: { "Content-Type": "application/json", diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 8efe81cce..de5c961e8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -2,89 +2,193 @@ import { useAtomValue } from "jotai"; import { useParams, useRouter } from 
"next/navigation"; -import { useEffect, useMemo } from "react"; +import { useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; import { - globalLlmConfigStatusAtom, - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + createNewLLMConfigMutationAtom, + updateLLMPreferencesMutationAtom, +} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { Logo } from "@/components/Logo"; -import { ModelProviderConnectionsPanel } from "@/components/settings/model-connections/model-provider-connections-panel"; +import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { hasEnabledChatModel, isLlmOnboardingComplete } from "@/lib/onboarding"; +import { isLlmOnboardingComplete } from "@/lib/onboarding"; export default function OnboardPage() { const router = useRouter(); const params = useParams(); const searchSpaceId = Number(params.search_space_id); - const { data: globalConnections = [], isLoading: globalLoading } = useAtomValue( - globalModelConnectionsAtom - ); - const { data: connections = [] } = useAtomValue(modelConnectionsAtom); - const { data: roles = {}, isLoading: rolesLoading } = useAtomValue(modelRolesAtom); - const { data: globalConfigStatus, isLoading: globalConfigStatusLoading } = - useAtomValue(globalLlmConfigStatusAtom); + // Queries + const { + data: globalConfigs = [], + isFetching: globalConfigsLoading, + isSuccess: globalConfigsLoaded, + } = useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences = {}, isFetching: preferencesLoading 
} = + useAtomValue(llmPreferencesAtom); + // Mutations + const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue( + createNewLLMConfigMutationAtom + ); + const { mutateAsync: updatePreferences, isPending: isUpdatingPreferences } = useAtomValue( + updateLLMPreferencesMutationAtom + ); + + // State + const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); + const hasAttemptedAutoConfig = useRef(false); + + // Check authentication useEffect(() => { - if (!getBearerToken()) redirectToLogin(); + const token = getBearerToken(); + if (!token) { + redirectToLogin(); + } }, []); - const hasUsableChatModel = useMemo( - () => hasEnabledChatModel([...globalConnections, ...connections]), - [globalConnections, connections] + const isOnboardingComplete = isLlmOnboardingComplete( + preferences.agent_llm_id, + globalConfigs.length > 0 ); - const onboardingComplete = isLlmOnboardingComplete( - roles.chat_model_id, - globalConnections, - connections - ); - - const isLoading = globalLoading || rolesLoading || globalConfigStatusLoading; - - // Onboarding only applies when no global_llm_config.yaml exists. If a global - // config is present (or onboarding is already complete), leave this page. 
- const shouldLeaveOnboarding = - !isLoading && (Boolean(globalConfigStatus?.exists) || onboardingComplete); - useEffect(() => { - if (shouldLeaveOnboarding) { - router.replace(`/dashboard/${searchSpaceId}/new-chat`); + if (!preferencesLoading && globalConfigsLoaded && isOnboardingComplete) { + router.push(`/dashboard/${searchSpaceId}/new-chat`); } - }, [shouldLeaveOnboarding, router, searchSpaceId]); + }, [preferencesLoading, globalConfigsLoaded, isOnboardingComplete, router, searchSpaceId]); - useGlobalLoadingEffect(isLoading || shouldLeaveOnboarding); + useEffect(() => { + const autoConfigureWithGlobal = async () => { + if (hasAttemptedAutoConfig.current) return; + if (globalConfigsLoading || preferencesLoading) return; + if (!globalConfigsLoaded) return; + if (isOnboardingComplete) return; - if (isLoading || shouldLeaveOnboarding) return null; + if (globalConfigs.length > 0) { + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + + try { + const firstGlobalConfig = globalConfigs[0]; + + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: firstGlobalConfig.id, + }, + }); + + toast.success("AI configured automatically!", { + description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`, + }); + + router.push(`/dashboard/${searchSpaceId}/new-chat`); + } catch (error) { + console.error("Auto-configuration failed:", error); + toast.error("Auto-configuration failed. 
Please add a configuration manually."); + setIsAutoConfiguring(false); + } + } + }; + + autoConfigureWithGlobal(); + }, [ + globalConfigs, + globalConfigsLoading, + globalConfigsLoaded, + preferencesLoading, + isOnboardingComplete, + updatePreferences, + searchSpaceId, + router, + ]); + + const handleSubmit = async (formData: LLMConfigFormData) => { + try { + const newConfig = await createConfig(formData); + + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: newConfig.id, + }, + }); + + toast.success("Configuration created!", { + description: "Redirecting to chat...", + }); + + router.push(`/dashboard/${searchSpaceId}/new-chat`); + } catch (error) { + console.error("Failed to create config:", error); + if (error instanceof Error) { + toast.error(error.message || "Failed to create configuration"); + } + } + }; + + const isSubmitting = isCreating || isUpdatingPreferences; + + const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring; + useGlobalLoadingEffect(isLoading); + + if (isLoading) { + return null; + } + + if (globalConfigs.length > 0 && !isAutoConfiguring) { + return null; + } return ( -
-
- -
-

Choose a model

-

- Connect any supported provider, then enable the models you want SurfSense to use. -

+
+
+ {/* Header */} +
+ +
+

Configure Your AI

+

+ Add your LLM provider to get started with SurfSense +

+
+
+ + {/* Form card */} +
+ +
+ + {/* Footer */} +
+ +

You can add more configurations later

- router.push(`/dashboard/${searchSpaceId}/new-chat`)} - > - Start - - } - showAddProviderHeader={false} - />
); diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx new file mode 100644 index 000000000..b300f8078 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx @@ -0,0 +1,6 @@ +import { ImageModelManager } from "@/components/settings/image-model-manager"; + +export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { + const { search_space_id } = await params; + return ; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index bb928f8f7..22f68edab 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -1,6 +1,15 @@ "use client"; -import { BookText, Cpu, Earth, Settings, UserKey } from "lucide-react"; +import { + BookText, + Bot, + CircleUser, + Earth, + ImageIcon, + ListChecks, + ScanEye, + UserKey, +} from "lucide-react"; import Link from "next/link"; import { useSelectedLayoutSegment } from "next/navigation"; import { useTranslations } from "next-intl"; @@ -11,7 +20,10 @@ import { cn } from "@/lib/utils"; export type SearchSpaceSettingsTab = | "general" + | "roles" | "models" + | "image-models" + | "vision-models" | "team-roles" | "prompts" | "public-links"; @@ -43,12 +55,27 @@ export function SearchSpaceSettingsLayoutShell({ { value: "general" as const, label: t("nav_general"), - icon: , + icon: , + }, + { + value: "roles" as const, + label: t("nav_role_assignments"), + icon: , }, { value: "models" as const, - label: t("nav_models"), - icon: , + label: t("nav_agent_models"), + icon: , + }, + { + value: "image-models" as const, + label: t("nav_image_models"), + 
icon: , + }, + { + value: "vision-models" as const, + label: t("nav_vision_models"), + icon: , }, { value: "team-roles" as const, diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx index c97ef7630..d68194782 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx @@ -1,6 +1,6 @@ -import { ModelConnectionsSettings } from "@/components/settings/model-connections-settings"; +import { AgentModelManager } from "@/components/settings/agent-model-manager"; export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { const { search_space_id } = await params; - return ; + return ; } diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx new file mode 100644 index 000000000..5bad50cd3 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx @@ -0,0 +1,6 @@ +import { LLMRoleManager } from "@/components/settings/llm-role-manager"; + +export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { + const { search_space_id } = await params; + return ; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx new file mode 100644 index 000000000..06aea003a --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx @@ -0,0 +1,6 @@ +import { VisionModelManager } from "@/components/settings/vision-model-manager"; + +export default async function Page({ params }: { params: Promise<{ 
search_space_id: string }> }) { + const { search_space_id } = await params; + return ; +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx index 4a3c5e9e7..b0cb6699c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/MessagingChannelsContent.tsx @@ -1,11 +1,10 @@ "use client"; -import { AlertTriangle, RefreshCw, ShieldAlert } from "lucide-react"; +import { RefreshCw, ShieldAlert } from "lucide-react"; import { useParams } from "next/navigation"; import { QRCodeSVG } from "qrcode.react"; import { useCallback, useEffect, useState } from "react"; import { toast } from "sonner"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { @@ -20,7 +19,7 @@ import { Skeleton } from "@/components/ui/skeleton"; import type { SearchSpace } from "@/contracts/types/search-space.types"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cn } from "@/lib/utils"; type GatewayConnection = { @@ -40,7 +39,6 @@ type GatewayConnection = { }; type GatewayConfig = { - enabled: boolean; telegram_enabled: boolean; whatsapp_intake_mode: "disabled" | "cloud" | "baileys"; slack_enabled: boolean; @@ -49,14 +47,6 @@ type GatewayConfig = { type GatewayConfigState = GatewayConfig | null; -const DISABLED_GATEWAY_CONFIG: GatewayConfig = { - enabled: false, - telegram_enabled: false, - whatsapp_intake_mode: "disabled", - slack_enabled: false, - 
discord_enabled: false, -}; - type Pairing = { binding_id: number; code: string; @@ -90,26 +80,16 @@ export function MessagingChannelsContent() { const whatsappMode = gatewayConfig?.whatsapp_intake_mode ?? "disabled"; const slackGatewayEnabled = gatewayConfig?.slack_enabled ?? false; const discordGatewayEnabled = gatewayConfig?.discord_enabled ?? false; - const gatewayDisabled = gatewayConfig?.enabled === false; const fetchConnections = useCallback(async (platform?: GatewayPlatform) => { - const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/connections", platform ? { platform } : undefined) - ); - if (!res.ok) return []; - const data = await res.json(); - return Array.isArray(data) ? (data as GatewayConnection[]) : []; + const query = platform ? `?platform=${encodeURIComponent(platform)}` : ""; + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/connections${query}`); + return (await res.json()) as GatewayConnection[]; }, []); - const fetchGatewayConfig = useCallback(async (): Promise => { - const res = await authenticatedFetch(buildBackendUrl("/api/v1/gateway/config")); - if (!res.ok) return DISABLED_GATEWAY_CONFIG; - const data = (await res.json()) as Partial; - return { - ...DISABLED_GATEWAY_CONFIG, - ...data, - enabled: data.enabled ?? 
true, - }; + const fetchGatewayConfig = useCallback(async () => { + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/config`); + return (await res.json()) as GatewayConfig; }, []); const refresh = useCallback(async () => { @@ -145,9 +125,7 @@ export function MessagingChannelsContent() { const refreshBaileysHealth = useCallback(async () => { if (whatsappMode !== "baileys") return; - const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/whatsapp/baileys/health") - ); + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/whatsapp/baileys/health`); if (!res.ok) return; const data = (await res.json()) as BaileysHealth; setBaileysHealth(data); @@ -158,7 +136,7 @@ export function MessagingChannelsContent() { }, [refreshBaileysHealth]); async function startPairing(platform: PairingPlatform) { - const res = await authenticatedFetch(buildBackendUrl("/api/v1/gateway/bindings/start"), { + const res = await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/bindings/start`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ platform, search_space_id: searchSpaceId }), @@ -170,7 +148,7 @@ export function MessagingChannelsContent() { async function installSlackGateway() { const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/slack/install", { search_space_id: searchSpaceId }) + `${BACKEND_URL}/api/v1/gateway/slack/install?search_space_id=${searchSpaceId}` ); if (!res.ok) return; const data = (await res.json()) as { auth_url?: string }; @@ -181,7 +159,7 @@ export function MessagingChannelsContent() { async function installDiscordGateway() { const res = await authenticatedFetch( - buildBackendUrl("/api/v1/gateway/discord/install", { search_space_id: searchSpaceId }) + `${BACKEND_URL}/api/v1/gateway/discord/install?search_space_id=${searchSpaceId}` ); if (!res.ok) return; const data = (await res.json()) as { auth_url?: string }; @@ -203,8 +181,8 @@ export function 
MessagingChannelsContent() { async function revoke(connection: GatewayConnection) { const url = connection.route_type === "account" && connection.account_id - ? buildBackendUrl(`/api/v1/gateway/accounts/${connection.account_id}`) - : buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}`); + ? `${BACKEND_URL}/api/v1/gateway/accounts/${connection.account_id}` + : `${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}`; await authenticatedFetch(url, { method: "DELETE", }); @@ -227,8 +205,8 @@ export function MessagingChannelsContent() { ); const url = connection.route_type === "account" && connection.account_id - ? buildBackendUrl(`/api/v1/gateway/accounts/${connection.account_id}/search-space`) - : buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}/search-space`); + ? `${BACKEND_URL}/api/v1/gateway/accounts/${connection.account_id}/search-space` + : `${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}/search-space`; const res = await authenticatedFetch(url, { method: "PATCH", headers: { "Content-Type": "application/json" }, @@ -244,7 +222,7 @@ export function MessagingChannelsContent() { } async function resume(connection: GatewayConnection) { - await authenticatedFetch(buildBackendUrl(`/api/v1/gateway/bindings/${connection.id}/resume`), { + await authenticatedFetch(`${BACKEND_URL}/api/v1/gateway/bindings/${connection.id}/resume`, { method: "POST", }); await refreshPlatform(connection.platform as GatewayPlatform); @@ -403,21 +381,7 @@ export function MessagingChannelsContent() {
{isGatewayConfigLoading ? renderGatewaySkeletons() : null} - {!isGatewayConfigLoading && gatewayDisabled ? ( - - - Messaging Channels coming soon - -

- Soon you'll be able to connect WhatsApp, Telegram, Slack, and Discord to your - SurfSense agent so you can ask questions, route messages to search spaces, and get - answers from your knowledge base without leaving your chat app. -

-
-
- ) : null} - - {!isGatewayConfigLoading && !gatewayDisabled && !hasEnabledGateway ? ( + {!isGatewayConfigLoading && !hasEnabledGateway ? ( No messaging gateways enabled @@ -425,7 +389,7 @@ export function MessagingChannelsContent() { ) : null} - {!gatewayDisabled && telegramGatewayEnabled ? ( + {telegramGatewayEnabled ? (
@@ -461,7 +425,7 @@ export function MessagingChannelsContent() { ) : null} - {!gatewayDisabled && slackGatewayEnabled ? ( + {slackGatewayEnabled ? (
@@ -493,7 +457,7 @@ export function MessagingChannelsContent() { ) : null} - {!gatewayDisabled && discordGatewayEnabled ? ( + {discordGatewayEnabled ? (
@@ -525,7 +489,7 @@ export function MessagingChannelsContent() { ) : null} - {!gatewayDisabled && whatsappMode !== "disabled" ? ( + {whatsappMode !== "disabled" ? (
diff --git a/surfsense_web/app/dashboard/dashboard-shell.tsx b/surfsense_web/app/dashboard/dashboard-shell.tsx deleted file mode 100644 index f84cd56eb..000000000 --- a/surfsense_web/app/dashboard/dashboard-shell.tsx +++ /dev/null @@ -1,42 +0,0 @@ -"use client"; - -import { useEffect, useState } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; -import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; -import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { queryClient } from "@/lib/query-client/client"; - -export function DashboardShell({ children }: { children: React.ReactNode }) { - const [isCheckingAuth, setIsCheckingAuth] = useState(true); - - // Use the global loading screen - spinner animation won't reset - useGlobalLoadingEffect(isCheckingAuth); - - useEffect(() => { - async function checkAuth() { - let token = getBearerToken(); - if (!token) { - const synced = await ensureTokensFromElectron(); - if (synced) token = getBearerToken(); - } - if (!token) { - redirectToLogin(); - return; - } - queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] }); - setIsCheckingAuth(false); - } - checkAuth(); - }, []); - - // Return null while loading - the global provider handles the loading UI - if (isCheckingAuth) { - return null; - } - - return ( -
-
{children}
-
- ); -} diff --git a/surfsense_web/app/dashboard/layout.tsx b/surfsense_web/app/dashboard/layout.tsx index 6212c92e7..1f5481b15 100644 --- a/surfsense_web/app/dashboard/layout.tsx +++ b/surfsense_web/app/dashboard/layout.tsx @@ -1,14 +1,46 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; -import { DashboardShell } from "./dashboard-shell"; +"use client"; + +import { useEffect, useState } from "react"; +import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; +import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; +import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; +import { queryClient } from "@/lib/query-client/client"; interface DashboardLayoutProps { children: React.ReactNode; } export default function DashboardLayout({ children }: DashboardLayoutProps) { + const [isCheckingAuth, setIsCheckingAuth] = useState(true); + + // Use the global loading screen - spinner animation won't reset + useGlobalLoadingEffect(isCheckingAuth); + + useEffect(() => { + async function checkAuth() { + let token = getBearerToken(); + if (!token) { + const synced = await ensureTokensFromElectron(); + if (synced) token = getBearerToken(); + } + if (!token) { + redirectToLogin(); + return; + } + queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] }); + setIsCheckingAuth(false); + } + checkAuth(); + }, []); + + // Return null while loading - the global provider handles the loading UI + if (isCheckingAuth) { + return null; + } + return ( - - {children} - +
+
{children}
+
); } diff --git a/surfsense_web/app/desktop/login/layout.tsx b/surfsense_web/app/desktop/login/layout.tsx deleted file mode 100644 index 83556d314..000000000 --- a/surfsense_web/app/desktop/login/layout.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { RuntimeConfig } from "@/components/providers/runtime-config.server"; - -export default function DesktopLoginLayout({ children }: { children: React.ReactNode }) { - return {children}; -} diff --git a/surfsense_web/app/desktop/login/page.tsx b/surfsense_web/app/desktop/login/page.tsx index 0d91588e1..41c956f3e 100644 --- a/surfsense_web/app/desktop/login/page.tsx +++ b/surfsense_web/app/desktop/login/page.tsx @@ -8,7 +8,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms"; import { DEFAULT_SHORTCUTS, keyEventToAccelerator } from "@/components/desktop/shortcut-recorder"; -import { useIsGoogleAuth } from "@/components/providers/runtime-config"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; @@ -18,8 +17,9 @@ import { Spinner } from "@/components/ui/spinner"; import { useElectronAPI } from "@/hooks/use-platform"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { setBearerToken } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; +const isGoogleAuth = AUTH_TYPE === "GOOGLE"; type ShortcutKey = "generalAssist" | "quickAsk" | "screenshotAssist"; type ShortcutMap = typeof DEFAULT_SHORTCUTS; @@ -189,7 +189,6 @@ function HotkeyRow({ export default function DesktopLoginPage() { const router = useRouter(); const api = useElectronAPI(); - const isGoogleAuth = useIsGoogleAuth(); const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom); const [email, setEmail] = 
useState(""); @@ -240,7 +239,7 @@ export default function DesktopLoginPage() { const handleGoogleLogin = () => { if (isGoogleRedirecting) return; setIsGoogleRedirecting(true); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; const autoSetSearchSpace = async () => { diff --git a/surfsense_web/app/docs/layout.tsx b/surfsense_web/app/docs/layout.tsx index cc5f2118a..9311a45b4 100644 --- a/surfsense_web/app/docs/layout.tsx +++ b/surfsense_web/app/docs/layout.tsx @@ -8,15 +8,12 @@ const gridTemplate = `"sidebar header toc" "sidebar toc-popover toc" "sidebar main toc" 1fr / var(--fd-sidebar-col) minmax(0, 1fr) min-content`; -const docsSurfaceClass = - "bg-main-panel [--color-fd-background:var(--main-panel)] [--color-fd-card:var(--main-panel)] [--color-fd-popover:var(--main-panel)] [--color-fd-muted:var(--main-panel)] [--color-fd-secondary:var(--main-panel)]"; - export default function Layout({ children }: { children: ReactNode }) { return ( { try { - const res = await fetch(`${SERVER_BACKEND_URL}/api/v1/public/anon-chat/models`, { + const res = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/models`, { next: { revalidate: 3600 }, }); if (!res.ok) return []; diff --git a/surfsense_web/app/verify-token/route.ts b/surfsense_web/app/verify-token/route.ts index 9df460779..b7ed762de 100644 --- a/surfsense_web/app/verify-token/route.ts +++ b/surfsense_web/app/verify-token/route.ts @@ -1,16 +1,12 @@ import { type NextRequest, NextResponse } from "next/server"; -function getBackendBaseUrl() { - const base = - process.env.SURFSENSE_BACKEND_INTERNAL_URL || - // TODO: Remove FASTAPI_BACKEND_INTERNAL_URL after the post-Caddy env migration window. 
- process.env.FASTAPI_BACKEND_INTERNAL_URL || - "http://backend:8000"; - return base.replace(/\/+$/, ""); -} +const backendBaseUrl = (process.env.INTERNAL_FASTAPI_BACKEND_URL || "http://backend:8000").replace( + /\/+$/, + "" +); export async function GET(request: NextRequest) { - const response = await fetch(`${getBackendBaseUrl()}/verify-token`, { + const response = await fetch(`${backendBaseUrl}/verify-token`, { method: "GET", headers: { Authorization: request.headers.get("authorization") || "", diff --git a/surfsense_web/atoms/automations/automations-mutation.atoms.ts b/surfsense_web/atoms/automations/automations-mutation.atoms.ts index 288d97c63..a81cd1578 100644 --- a/surfsense_web/atoms/automations/automations-mutation.atoms.ts +++ b/surfsense_web/atoms/automations/automations-mutation.atoms.ts @@ -57,9 +57,9 @@ export const createAutomationMutationAtom = atomWithMutation(() => ({ task_count: variables.definition.plan.length, trigger_type: variables.triggers?.[0]?.type ?? "none", has_schedule: (variables.triggers?.length ?? 
0) > 0, - chat_model_id: variables.definition.models?.chat_model_id, - image_gen_model_id: variables.definition.models?.image_gen_model_id, - vision_model_id: variables.definition.models?.vision_model_id, + agent_llm_id: variables.definition.models?.agent_llm_id, + image_generation_config_id: variables.definition.models?.image_generation_config_id, + vision_llm_config_id: variables.definition.models?.vision_llm_config_id, tags_count: variables.definition.metadata?.tags?.length, }); }, diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts new file mode 100644 index 000000000..922c398c9 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts @@ -0,0 +1,96 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateImageGenConfigRequest, + CreateImageGenConfigResponse, + DeleteImageGenConfigResponse, + GetImageGenConfigsResponse, + UpdateImageGenConfigRequest, + UpdateImageGenConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new ImageGenerationConfig + */ +export const createImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateImageGenConfigRequest) => { + return imageGenConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => { + 
toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create image model"); + }, + }; +}); + +/** + * Mutation atom for updating an existing ImageGenerationConfig + */ +export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateImageGenConfigRequest) => { + return imageGenConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.imageGenConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update image model"); + }, + }; +}); + +/** + * Mutation atom for deleting an ImageGenerationConfig + */ +export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["image-gen-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: { id: number; name: string }) => { + return imageGenConfigApiService.deleteConfig(request.id); + }, + onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + (oldData: GetImageGenConfigsResponse | undefined) => { + if (!oldData) return oldData; + return 
oldData.filter((config) => config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete image model"); + }, + }; +}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts new file mode 100644 index 000000000..a45e69a03 --- /dev/null +++ b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts @@ -0,0 +1,33 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching user-created image gen configs for the active search space + */ +export const imageGenConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching global image gen configs (from YAML, negative IDs) + */ +export const globalImageGenConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.imageGenConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + queryFn: async () => { + return imageGenConfigApiService.getGlobalConfigs(); + }, + }; +}); diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts deleted file mode 100644 index f00bf76f9..000000000 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ /dev/null @@ -1,214 +0,0 @@ -import { atomWithMutation } from 
"jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - ConnectionCreateRequest, - ConnectionRead, - ConnectionUpdateRequest, - ModelCreateRequest, - ModelPreviewRead, - ModelRead, - ModelRoles, - ModelsBulkUpdateRequest, - ModelTestPreviewRequest, - ModelUpdateRequest, - VerifyConnectionResponse, -} from "@/contracts/types/model-connections.types"; -import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -function invalidateModelConnections(searchSpaceId: number) { - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.all(searchSpaceId), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - }); -} - -function upsertModelConnection(searchSpaceId: number, connection: ConnectionRead) { - queryClient.setQueryData( - cacheKeys.modelConnections.all(searchSpaceId), - (current = []) => { - if (current.some((item) => item.id === connection.id)) { - return current.map((item) => (item.id === connection.id ? connection : item)); - } - return [...current, connection]; - } - ); -} - -export const createModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "create"], - mutationFn: (request: ConnectionCreateRequest) => - modelConnectionsApiService.createConnection(request), - onSuccess: (connection: ConnectionRead, request: ConnectionCreateRequest) => { - const resolvedSearchSpaceId = Number( - request.search_space_id ?? connection.search_space_id ?? 
searchSpaceId - ); - toast.success("Connection created"); - if (resolvedSearchSpaceId > 0) { - upsertModelConnection(resolvedSearchSpaceId, connection); - invalidateModelConnections(resolvedSearchSpaceId); - } - }, - onError: (error: Error) => toast.error(error.message || "Failed to create connection"), - }; -}); - -export const updateModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "update"], - mutationFn: ({ id, data }: { id: number; data: ConnectionUpdateRequest }) => - modelConnectionsApiService.updateConnection(id, data), - onSuccess: () => { - toast.success("Connection updated"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to update connection"), - }; -}); - -export const deleteModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "delete"], - mutationFn: (id: number) => modelConnectionsApiService.deleteConnection(id), - onSuccess: () => { - toast.success("Connection deleted"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to delete connection"), - }; -}); - -export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "verify"], - mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), - onSuccess: (result: VerifyConnectionResponse) => { - if (result.ok) { - toast.success("Connection verified"); - } else { - // Non-fatal: many providers lack a /models endpoint yet still serve - // chat. Guide the user to add model IDs manually instead of alarming. - toast.warning( - result.message - ? 
`${result.message} Chat may still work — add model IDs manually.` - : "Couldn't list models. Chat may still work — add model IDs manually." - ); - } - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), - }; -}); - -export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["model-connections", "discover"], - mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: (models: ModelRead[]) => { - toast.success( - models.length ? `${models.length} models discovered` : "No models found for this connection" - ); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to discover models"), - }; -}); - -export const previewConnectionModelsMutationAtom = atomWithMutation(() => { - return { - mutationKey: ["model-connections", "discover-preview"], - mutationFn: (request: ConnectionCreateRequest) => - modelConnectionsApiService.previewModels(request), - onSuccess: (_models: ModelPreviewRead[]) => {}, - onError: (error: Error) => toast.error(error.message || "Failed to discover models"), - }; -}); - -export const testPreviewModelMutationAtom = atomWithMutation(() => { - return { - mutationKey: ["model-connections", "test-preview"], - mutationFn: (request: ModelTestPreviewRequest) => - modelConnectionsApiService.testPreviewModel(request), - onSuccess: (result: VerifyConnectionResponse) => { - if (!result.ok) { - toast.error(result.message || "Model test failed"); - } - }, - onError: (error: Error) => toast.error(error.message || "Failed to test model"), - }; -}); - -export const addManualModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "add-manual"], - mutationFn: ({ connectionId, data }: { 
connectionId: number; data: ModelCreateRequest }) => - modelConnectionsApiService.addManualModel(connectionId, data), - onSuccess: () => { - toast.success("Model added"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to add model"), - }; -}); - -export const updateModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "update"], - mutationFn: ({ id, data }: { id: number; data: ModelUpdateRequest }) => - modelConnectionsApiService.updateModel(id, data), - onSuccess: () => invalidateModelConnections(searchSpaceId), - onError: (error: Error) => toast.error(error.message || "Failed to update model"), - }; -}); - -export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "bulk-update"], - mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) => - modelConnectionsApiService.bulkUpdateModels(connectionId, data), - onSuccess: () => invalidateModelConnections(searchSpaceId), - onError: (error: Error) => toast.error(error.message || "Failed to update models"), - }; -}); - -export const testModelMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - mutationKey: ["models", "test"], - mutationFn: (id: number) => modelConnectionsApiService.testModel(id), - onSuccess: (result: VerifyConnectionResponse) => { - if (result.ok) toast.success("Model test succeeded"); - else toast.error(result.message || "Model test failed"); - invalidateModelConnections(searchSpaceId); - }, - onError: (error: Error) => toast.error(error.message || "Failed to test model"), - }; -}); - -export const updateModelRolesMutationAtom = atomWithMutation((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - 
mutationKey: ["model-roles", "update"], - mutationFn: (roles: ModelRoles) => - modelConnectionsApiService.updateModelRoles(searchSpaceId, roles), - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - }); - }, - onError: (error: Error) => toast.error(error.message || "Failed to update model roles"), - }; -}); diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts deleted file mode 100644 index 04dad9b21..000000000 --- a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const globalModelConnectionsAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.global(), - enabled: !!getBearerToken(), - staleTime: 10 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getGlobalConnections(), -})); - -export const globalLlmConfigStatusAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.globalConfigStatus(), - enabled: !!getBearerToken(), - staleTime: 60 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getGlobalLlmConfigStatus(), -})); - -export const modelProvidersAtom = atomWithQuery(() => ({ - queryKey: cacheKeys.modelConnections.providers(), - enabled: !!getBearerToken(), - staleTime: 60 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getModelProviders(), -})); - -export const modelConnectionsAtom = atomWithQuery((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - queryKey: cacheKeys.modelConnections.all(searchSpaceId), - enabled: 
!!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getConnections(searchSpaceId), - }; -}); - -export const modelRolesAtom = atomWithQuery((get) => { - const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); - return { - queryKey: cacheKeys.modelConnections.roles(searchSpaceId), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: () => modelConnectionsApiService.getModelRoles(searchSpaceId), - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts new file mode 100644 index 000000000..476d89d4c --- /dev/null +++ b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts @@ -0,0 +1,132 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateNewLLMConfigRequest, + CreateNewLLMConfigResponse, + DeleteNewLLMConfigRequest, + DeleteNewLLMConfigResponse, + GetNewLLMConfigsResponse, + UpdateLLMPreferencesRequest, + UpdateNewLLMConfigRequest, + UpdateNewLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Mutation atom for creating a new NewLLMConfig + */ +export const createNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateNewLLMConfigRequest) => { + return newLLMConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => { + 
toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create model"); + }, + }; +}); + +/** + * Mutation atom for updating an existing NewLLMConfig + */ +export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateNewLLMConfigRequest) => { + return newLLMConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.newLLMConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update"); + }, + }; +}); + +/** + * Mutation atom for deleting a NewLLMConfig + */ +export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["new-llm-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => { + return newLLMConfigApiService.deleteConfig({ id: request.id }); + }, + onSuccess: ( + _: DeleteNewLLMConfigResponse, + request: DeleteNewLLMConfigRequest & { name: string } + ) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + (oldData: GetNewLLMConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => 
config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete"); + }, + }; +}); + +/** + * Mutation atom for updating LLM preferences (role assignments) + */ +export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["llm-preferences", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateLLMPreferencesRequest) => { + return newLLMConfigApiService.updateLLMPreferences(request); + }, + onSuccess: (_data, request: UpdateLLMPreferencesRequest) => { + queryClient.setQueryData( + cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), + (old: Record | undefined) => ({ ...old, ...request.data }) + ); + // Automation eligibility is derived from these model preferences + // (agent/image/vision). Invalidate it so the automations gate alert + // reflects the new selection without a manual refresh. 
+ queryClient.invalidateQueries({ + queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update LLM preferences"); + }, + }; +}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts new file mode 100644 index 000000000..410d061e5 --- /dev/null +++ b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts @@ -0,0 +1,98 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import type { LLMModel } from "@/contracts/enums/llm-models"; +import { LLM_MODELS } from "@/contracts/enums/llm-models"; +import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +/** + * Query atom for fetching all NewLLMConfigs for the active search space + */ +export const newLLMConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return newLLMConfigApiService.getConfigs({ + search_space_id: Number(searchSpaceId), + }); + }, + }; +}); + +/** + * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs) + */ +export const globalNewLLMConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.global(), + staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change + enabled: !!getBearerToken(), + queryFn: async () => { + return newLLMConfigApiService.getGlobalConfigs(); + }, + }; +}); + +/** + * Query atom for fetching LLM preferences (role assignments) for the active search space + */ +export const 
llmPreferencesAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, // 5 minutes + queryFn: async () => { + return newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)); + }, + }; +}); + +/** + * Query atom for fetching default system instructions template + */ +export const defaultSystemInstructionsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.defaultInstructions(), + staleTime: 60 * 60 * 1000, // 1 hour - this rarely changes + queryFn: async () => { + return newLLMConfigApiService.getDefaultSystemInstructions(); + }, + }; +}); + +/** + * Query atom for the dynamic model catalogue. + * Fetched from the backend (which proxies OpenRouter's public API). + * Falls back to the static hardcoded list on error. + */ +export const modelListAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.newLLMConfigs.modelList(), + staleTime: 60 * 60 * 1000, // 1 hour - models don't change often + placeholderData: LLM_MODELS, + queryFn: async (): Promise => { + const data = await newLLMConfigApiService.getModels(); + const dynamicModels = data.map((m) => ({ + value: m.value, + label: m.label, + provider: m.provider, + contextWindow: m.context_window ?? undefined, + })); + + // Providers covered by the dynamic API (from OpenRouter mapping). + // For uncovered providers (Ollama, Groq, Bedrock, etc.) keep the + // hand-curated static suggestions so users still see model options. 
+ const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); + const staticFallbacks = LLM_MODELS.filter((m) => !coveredProviders.has(m.provider)); + + return [...dynamicModels, ...staticFallbacks]; + }, + }; +}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts new file mode 100644 index 000000000..f46b977d5 --- /dev/null +++ b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts @@ -0,0 +1,87 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + CreateVisionLLMConfigRequest, + CreateVisionLLMConfigResponse, + DeleteVisionLLMConfigResponse, + GetVisionLLMConfigsResponse, + UpdateVisionLLMConfigRequest, + UpdateVisionLLMConfigResponse, +} from "@/contracts/types/new-llm-config.types"; +import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const createVisionLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "create"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: CreateVisionLLMConfigRequest) => { + return visionLLMConfigApiService.createConfig(request); + }, + onSuccess: (_: CreateVisionLLMConfigResponse, request: CreateVisionLLMConfigRequest) => { + toast.success(`${request.name} created`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to create vision model"); + }, + }; +}); + +export const updateVisionLLMConfigMutationAtom = 
atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "update"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: UpdateVisionLLMConfigRequest) => { + return visionLLMConfigApiService.updateConfig(request); + }, + onSuccess: (_: UpdateVisionLLMConfigResponse, request: UpdateVisionLLMConfigRequest) => { + toast.success(`${request.data.name ?? "Configuration"} updated`); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.visionLLMConfigs.byId(request.id), + }); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to update vision model"); + }, + }; +}); + +export const deleteVisionLLMConfigMutationAtom = atomWithMutation((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + mutationKey: ["vision-llm-configs", "delete"], + meta: { suppressGlobalErrorToast: true }, + enabled: !!searchSpaceId, + mutationFn: async (request: { id: number; name: string }) => { + return visionLLMConfigApiService.deleteConfig(request.id); + }, + onSuccess: (_: DeleteVisionLLMConfigResponse, request: { id: number; name: string }) => { + toast.success(`${request.name} deleted`); + queryClient.setQueryData( + cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + (oldData: GetVisionLLMConfigsResponse | undefined) => { + if (!oldData) return oldData; + return oldData.filter((config) => config.id !== request.id); + } + ); + }, + onError: (error: Error) => { + toast.error(error.message || "Failed to delete vision model"); + }, + }; +}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts new file mode 100644 index 000000000..906ce638f --- /dev/null +++ 
b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts @@ -0,0 +1,51 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import type { LLMModel } from "@/contracts/enums/llm-models"; +import { VISION_MODELS } from "@/contracts/enums/vision-providers"; +import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const visionLLMConfigsAtom = atomWithQuery((get) => { + const searchSpaceId = get(activeSearchSpaceIdAtom); + + return { + queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: async () => { + return visionLLMConfigApiService.getConfigs(Number(searchSpaceId)); + }, + }; +}); + +export const globalVisionLLMConfigsAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.visionLLMConfigs.global(), + staleTime: 10 * 60 * 1000, + queryFn: async () => { + return visionLLMConfigApiService.getGlobalConfigs(); + }, + }; +}); + +export const visionModelListAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.visionLLMConfigs.modelList(), + staleTime: 60 * 60 * 1000, + placeholderData: VISION_MODELS, + queryFn: async (): Promise => { + const data = await visionLLMConfigApiService.getModels(); + const dynamicModels = data.map((m) => ({ + value: m.value, + label: m.label, + provider: m.provider, + contextWindow: m.context_window ?? 
undefined, + })); + + const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); + const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); + + return [...dynamicModels, ...staticFallbacks]; + }, + }; +}); diff --git a/surfsense_web/components/agent-action-log/action-log-dialog.tsx b/surfsense_web/components/agent-action-log/action-log-dialog.tsx index 5f3b83db1..1d0eefc17 100644 --- a/surfsense_web/components/agent-action-log/action-log-dialog.tsx +++ b/surfsense_web/components/agent-action-log/action-log-dialog.tsx @@ -2,7 +2,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; -import { RefreshCw, Workflow } from "lucide-react"; +import { RefreshCcw, Workflow } from "lucide-react"; import { useCallback } from "react"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; @@ -112,7 +112,7 @@ export function ActionLogDialog() { className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" aria-label="Refresh action log" > - +
diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 59006b26e..d084ac0fd 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -26,9 +26,9 @@ import type { FC } from "react"; import { useEffect, useMemo, useRef, useState } from "react"; import { commentsEnabledAtom, targetCommentIdAtom } from "@/atoms/chat/current-thread.atom"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + globalNewLLMConfigsAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { CitationMetadataProvider, @@ -37,10 +37,7 @@ import { import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; -import { - type TokenUsageModelBreakdown, - useTokenUsage, -} from "@/components/assistant-ui/token-usage-context"; +import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container"; import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet"; @@ -271,81 +268,29 @@ function formatTurnCost(micros: number): string { return "$0"; } -function normalizeUsageModelKey(modelKey: string): string { - return modelKey.trim().replace(/^~/, ""); -} - -function bareModelKey(modelKey: string): string { - const normalized = normalizeUsageModelKey(modelKey); - const parts = normalized.split("/"); - return parts[parts.length - 1] 
|| normalized; -} - -function inferProviderFromModelKey(modelKey: string) { - const normalized = normalizeUsageModelKey(modelKey); - const [provider] = normalized.split("/"); - return provider && provider !== normalized ? provider : null; -} - -function titleCaseModelPart(part: string) { - if (!part) return ""; - const upper = part.toUpperCase(); - if (/^\d+(\.\d+)?[BKM]$/.test(upper)) return upper; - if (["gpt", "oai", "api", "llm", "vlm"].includes(part.toLowerCase())) return upper; - return part.charAt(0).toUpperCase() + part.slice(1); -} - -function humanizeModelId(modelKey: string): string { - const bare = bareModelKey(modelKey) - .replace(/:latest$/i, "") - .replace(/[-_]+/g, " ") - .trim(); - if (!bare) return modelKey; - return bare.split(/\s+/).map(titleCaseModelPart).join(" "); -} - const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ chatTurnId }) => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); const usage = useTokenUsage(messageId); - const { data: globalConnections = [] } = useAtomValue(globalModelConnectionsAtom); - const { data: localConnections = [] } = useAtomValue(modelConnectionsAtom); + const { data: localConfigs } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); - const modelConnectionByKey = useMemo(() => { - const map = new Map(); - for (const connection of [...globalConnections, ...localConnections]) { - for (const model of connection.models) { - const normalizedModelId = normalizeUsageModelKey(model.model_id); - const entry = { - name: model.display_name || model.model_id, - provider: connection.provider, - modelId: model.model_id, - }; - map.set(model.model_id, entry); - map.set(normalizedModelId, entry); - map.set(bareModelKey(model.model_id), entry); - } + const configByModel = useMemo(() => { + const map = new Map(); + for (const c of [...(globalConfigs ?? 
[]), ...(localConfigs ?? [])]) { + map.set(c.model_name, { name: c.name, provider: c.provider }); } return map; - }, [globalConnections, localConnections]); + }, [localConfigs, globalConfigs]); - const resolveModel = (modelKey: string, counts: TokenUsageModelBreakdown) => { - const normalizedKey = normalizeUsageModelKey(counts.model_id || counts.model || modelKey); - const connectionModel = - modelConnectionByKey.get(modelKey) ?? - modelConnectionByKey.get(normalizeUsageModelKey(modelKey)) ?? - modelConnectionByKey.get(normalizedKey) ?? - modelConnectionByKey.get(bareModelKey(normalizedKey)); - const provider = - counts.provider || connectionModel?.provider || inferProviderFromModelKey(normalizedKey); - const modelId = counts.model_id || connectionModel?.modelId || modelKey; - const name = counts.display_name || connectionModel?.name || humanizeModelId(modelId); - return { - name, - modelId, - icon: provider ? getProviderIcon(provider, { className: "size-3.5 shrink-0" }) : null, - }; + const resolveModel = (modelKey: string) => { + const parts = modelKey.split("/"); + const bare = parts[parts.length - 1] ?? modelKey; + const config = configByModel.get(modelKey) ?? configByModel.get(bare); + return config + ? { name: config.name, icon: getProviderIcon(config.provider, { className: "size-3.5" }) } + : { name: modelKey, icon: null }; }; const modelBreakdown = usage ? (usage.usage ?? usage.model_breakdown) : undefined; @@ -374,12 +319,12 @@ const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ ch {models.length > 0 ? 
( models.map(([model, counts]) => { - const { name, icon } = resolveModel(model, counts); + const { name, icon } = resolveModel(model); const costMicros = counts.cost_micros; return ( e.preventDefault()} > diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index 83308b642..dedada7a5 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -27,8 +27,8 @@ export interface ChatViewportProps { export const ChatViewport: FC = ({ children, footer }) => ( void; @@ -42,10 +42,17 @@ export const CirclebackConfig: FC = ({ connector, onNameC const doFetch = async () => { if (!connector.search_space_id) return; + const baseUrl = BACKEND_URL; + if (!baseUrl) { + console.error("NEXT_PUBLIC_FASTAPI_BACKEND_URL is not configured"); + setIsLoading(false); + return; + } + setIsLoading(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/webhooks/circleback/${connector.search_space_id}/info`), + `${baseUrl}/api/v1/webhooks/circleback/${connector.search_space_id}/info`, { signal: controller.signal } ); if (controller.signal.aborted) return; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index 1fc555471..011eeec96 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -13,7 +13,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getReauthEndpoint } from "@/lib/connector-telemetry"; -import { buildBackendUrl 
} from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cn } from "@/lib/utils"; import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; @@ -95,13 +95,12 @@ export const ConnectorEditView: FC = ({ if (!spaceId || !reauthEndpoint) return; setReauthing(true); try { - const response = await authenticatedFetch( - buildBackendUrl(reauthEndpoint, { - connector_id: connector.id, - space_id: spaceId, - return_url: window.location.pathname, - }) - ); + const backendUrl = BACKEND_URL; + const url = new URL(`${backendUrl}${reauthEndpoint}`); + url.searchParams.set("connector_id", String(connector.id)); + url.searchParams.set("space_id", String(spaceId)); + url.searchParams.set("return_url", window.location.pathname); + const response = await authenticatedFetch(url.toString()); if (!response.ok) { const data = await response.json().catch(() => ({})); toast.error(data.detail ?? "Failed to initiate re-authentication."); diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index 2f10152b8..45c174d74 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -16,7 +16,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { searchSourceConnector } from "@/contracts/types/connector.types"; import { OAUTH_RESULT_COOKIE, parseOAuthCallbackResult } from "@/contracts/types/oauth.types"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackConnectorConnected, trackConnectorDeleted, @@ -351,7 +351,9 @@ export const useConnectorDialog = () => { 
trackConnectorSetupStarted(Number(searchSpaceId), connector.connectorType, "oauth_click"); try { - const url = buildBackendUrl(connector.authEndpoint, { space_id: searchSpaceId }); + // Check if authEndpoint already has query parameters + const separator = connector.authEndpoint.includes("?") ? "&" : "?"; + const url = `${BACKEND_URL}${connector.authEndpoint}${separator}space_id=${searchSpaceId}`; const response = await authenticatedFetch(url, { method: "GET" }); diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx index f7b27441b..4977219f7 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/all-connectors-tab.tsx @@ -2,10 +2,10 @@ import { Search } from "lucide-react"; import type { FC } from "react"; -import { useIsSelfHosted } from "@/components/providers/runtime-config"; import { EnumConnectorName } from "@/contracts/enums/connector"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { usePlatform } from "@/hooks/use-platform"; +import { isSelfHosted } from "@/lib/env-config"; import { ConnectorCard } from "../components/connector-card"; import { COMPOSIO_CONNECTORS, @@ -22,11 +22,6 @@ type OAuthConnector = (typeof OAUTH_CONNECTORS)[number]; type ComposioConnector = (typeof COMPOSIO_CONNECTORS)[number]; type OtherConnector = (typeof OTHER_CONNECTORS)[number]; type CrawlerConnector = (typeof CRAWLERS)[number]; -type DeploymentFilterableConnector = { - readonly id: string; - readonly selfHostedOnly?: boolean; - readonly desktopOnly?: boolean; -}; /** * Extract the display name from a full connector name. 
@@ -71,14 +66,14 @@ export const AllConnectorsTab: FC = ({ onManage, onViewAccountsList, }) => { - const selfHosted = useIsSelfHosted(); + const selfHosted = isSelfHosted(); const { isDesktop } = usePlatform(); const matchesSearch = (title: string, description: string) => title.toLowerCase().includes(searchQuery.toLowerCase()) || description.toLowerCase().includes(searchQuery.toLowerCase()); - const passesDeploymentFilter = (c: DeploymentFilterableConnector) => + const passesDeploymentFilter = (c: { selfHostedOnly?: boolean; desktopOnly?: boolean }) => (!c.selfHostedOnly || selfHosted) && (!c.desktopOnly || isDesktop); // Filter connectors based on search and deployment mode diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index f53537cdc..05b684397 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -12,7 +12,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getReauthEndpoint } from "@/lib/connector-telemetry"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; @@ -61,13 +61,12 @@ export const ConnectorAccountsListView: FC = ({ if (!searchSpaceId || !endpoint) return; setReauthingId(connector.id); try { - const response = await authenticatedFetch( - buildBackendUrl(endpoint, { - connector_id: connector.id, - space_id: searchSpaceId, - return_url: window.location.pathname, - }) - ); + 
const backendUrl = BACKEND_URL; + const url = new URL(`${backendUrl}${endpoint}`); + url.searchParams.set("connector_id", String(connector.id)); + url.searchParams.set("space_id", String(searchSpaceId)); + url.searchParams.set("return_url", window.location.pathname); + const response = await authenticatedFetch(url.toString()); if (!response.ok) { const data = await response.json().catch(() => ({})); toast.error(data.detail ?? "Failed to initiate re-authentication."); diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index cd32ca920..ee36e8499 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -110,7 +110,7 @@ const MarkdownTextImpl = () => { return ( { editorRef.current?.focus(); }, [isDesktop, showDocumentPopover, showPromptPicker, threadId]); - const handleChatModelSelected = useCallback(() => { - if (!isDesktop) return; - editorRef.current?.focus(); - }, [isDesktop]); - // Close document picker when a sidebar slide-out panel (inbox, etc.) opens. // React only on changes to the tick — comparing against the previously-seen // value preserves the one-shot semantics of the prior window-event approach @@ -937,11 +931,7 @@ const Composer: FC = () => { className="min-h-[48px] sm:min-h-[24px] **:data-slate-placeholder:font-normal" />
- +
{ interface ComposerActionProps { isBlockedByOtherUser?: boolean; - searchSpaceId: number; - onChatModelSelected?: () => void; } -const ComposerAction: FC = ({ - isBlockedByOtherUser = false, - searchSpaceId, - onChatModelSelected, -}) => { +const ComposerAction: FC = ({ isBlockedByOtherUser = false }) => { const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const [toolsPopoverOpen, setToolsPopoverOpen] = useState(false); @@ -996,9 +980,9 @@ const ComposerAction: FC = ({ if (url) setPendingScreenImages((prev) => [...prev, url]); }, [electronAPI, setPendingScreenImages]); - const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom); - const { data: modelConnections } = useAtomValue(modelConnectionsAtom); - const { data: modelRoles } = useAtomValue(modelRolesAtom); + const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences } = useAtomValue(llmPreferencesAtom); const { data: agentTools } = useAtomValue(agentToolsAtom); const disabledTools = useAtomValue(disabledToolsAtom); @@ -1085,18 +1069,15 @@ const ComposerAction: FC = ({ }, [hydrateDisabled]); const hasModelConfigured = useMemo(() => { - const chatModelId = modelRoles?.chat_model_id ?? 0; - if (chatModelId === 0) { - return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => - connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) - ); + if (!preferences) return false; + const agentLlmId = preferences.agent_llm_id; + if (agentLlmId === null || agentLlmId === undefined) return false; + + if (agentLlmId <= 0) { + return globalConfigs?.some((c) => c.id === agentLlmId) ?? false; } - return [...(globalModelConnections ?? []), ...(modelConnections ?? 
[])].some((connection) => - connection.models.some( - (model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat) - ) - ); - }, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]); + return userConfigs?.some((c) => c.id === agentLlmId) ?? false; + }, [preferences, globalConfigs, userConfigs]); const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser; @@ -1578,11 +1559,6 @@ const ComposerAction: FC = ({
)}
- !thread.isRunning}> = ({ isBlockedByOtherUser ? "Wait for AI to finish responding" : !hasModelConfigured - ? "Please select a model to start chatting" + ? "Please select a model from the header to start chatting" : isComposerEmpty ? "Enter a message or add a screenshot to send" : "Send message" diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index 8db8c2b50..dd80bcac3 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -9,18 +9,6 @@ import { useSyncExternalStore, } from "react"; -export interface TokenUsageModelBreakdown { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - cost_micros?: number; - model?: string | null; - model_ref?: string | null; - model_id?: string | null; - display_name?: string | null; - provider?: string | null; -} - export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; @@ -32,8 +20,24 @@ export interface TokenUsageData { * before the migration won't have it. 
*/ cost_micros?: number; - usage?: Record; - model_breakdown?: Record; + usage?: Record< + string, + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } + >; + model_breakdown?: Record< + string, + { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + } + >; } type Listener = () => void; diff --git a/surfsense_web/components/auth/sign-in-button.tsx b/surfsense_web/components/auth/sign-in-button.tsx index 581e37603..7f5a77f36 100644 --- a/surfsense_web/components/auth/sign-in-button.tsx +++ b/surfsense_web/components/auth/sign-in-button.tsx @@ -3,7 +3,7 @@ import Link from "next/link"; import { useState } from "react"; import { Button } from "@/components/ui/button"; -import { BUILD_TIME_AUTH_TYPE, buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; @@ -46,14 +46,14 @@ interface SignInButtonProps { } export const SignInButton = ({ variant = "desktop" }: SignInButtonProps) => { - const isGoogleAuth = BUILD_TIME_AUTH_TYPE === "GOOGLE"; + const isGoogleAuth = AUTH_TYPE === "GOOGLE"; const [isRedirecting, setIsRedirecting] = useState(false); const handleGoogleLogin = () => { if (isRedirecting) return; setIsRedirecting(true); trackLoginAttempt("google"); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; const getClassName = () => { diff --git a/surfsense_web/components/documents/download-original-button.tsx b/surfsense_web/components/documents/download-original-button.tsx index e04ead89a..b79b289b4 100644 --- a/surfsense_web/components/documents/download-original-button.tsx +++ b/surfsense_web/components/documents/download-original-button.tsx @@ -7,7 +7,7 @@ import { Button } from "@/components/ui/button"; import { Spinner 
} from "@/components/ui/spinner"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; interface DownloadOriginalButtonProps { documentId: number; @@ -41,7 +41,7 @@ export function DownloadOriginalButton({ documentId }: DownloadOriginalButtonPro setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/documents/${documentId}/download-original`), + `${BACKEND_URL}/api/v1/documents/${documentId}/download-original`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 75283c81f..01983cbe1 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -17,7 +17,6 @@ import { toast } from "sonner"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { DownloadOriginalButton } from "@/components/documents/download-original-button"; import { VersionHistoryButton } from "@/components/documents/version-history"; -import { PlateErrorBoundary } from "@/components/editor/plate-error-boundary"; import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { fetchMemoryEditorDocument, @@ -35,15 +34,14 @@ import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; import { inferMonacoLanguageFromPath } from "@/lib/editor-language"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; const PlateEditor = dynamic( () => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor 
})), { ssr: false, loading: () => } ); -const LARGE_DOCUMENT_THRESHOLD = 1 * 1024 * 1024; // 1MB, matches backend -const LARGE_DOCUMENT_LINE_THRESHOLD = 5000; +const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB interface EditorContent { document_id: number; @@ -51,11 +49,9 @@ interface EditorContent { document_type?: string; source_markdown: string; content_size_bytes?: number; - line_count?: number; chunk_count?: number; viewer_mode?: ViewerMode; editor_plate_max_bytes?: number; - editor_plate_max_lines?: number; } const EDITABLE_DOCUMENT_TYPES = new Set(["FILE", "NOTE"]); @@ -122,15 +118,6 @@ function getUtf8ByteSize(value: string): number { return new TextEncoder().encode(value).byteLength; } -function countLines(value: string): number { - if (!value) return 0; - let count = 1; - for (let i = 0; i < value.length; i++) { - if (value.charCodeAt(i) === 10) count++; - } - return count; -} - function formatBytes(bytes: number): string { if (bytes >= 1024 * 1024) { return `${(bytes / 1024 / 1024).toFixed(1)}MB`; @@ -197,17 +184,10 @@ export function EditorPanelContent({ ); const plateMaxBytes = editorDoc?.editor_plate_max_bytes ?? LARGE_DOCUMENT_THRESHOLD; - const plateMaxLines = editorDoc?.editor_plate_max_lines ?? LARGE_DOCUMENT_LINE_THRESHOLD; - const docSizeBytes = editorDoc?.content_size_bytes ?? 0; - const docLineCount = - editorDoc?.line_count ?? - (editorDoc?.source_markdown ? countLines(editorDoc.source_markdown) : 0); - const isLargeDocument = docSizeBytes > plateMaxBytes || docLineCount > plateMaxLines; + const isLargeDocument = (editorDoc?.content_size_bytes ?? 0) > plateMaxBytes; const viewerMode: ViewerMode = isMemoryMode ? "plate" - : editorDoc?.viewer_mode === "monaco" || isLargeDocument - ? "monaco" - : "plate"; + : (editorDoc?.viewer_mode ?? (isLargeDocument ? 
"monaco" : "plate")); useEffect(() => { const controller = new AbortController(); @@ -280,12 +260,10 @@ export function EditorPanelContent({ return; } - const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` - ), - { method: "GET" } + const url = new URL( + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); + const response = await authenticatedFetch(url.toString(), { method: "GET" }); if (controller.signal.aborted) return; @@ -424,7 +402,7 @@ export function EditorPanelContent({ return; } const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -443,8 +421,7 @@ export function EditorPanelContent({ setEditedMarkdown(null); if (!options?.silent) { const savedSizeBytes = getUtf8ByteSize(markdownRef.current); - const savedLineCount = countLines(markdownRef.current); - if (savedSizeBytes > plateMaxBytes || savedLineCount > plateMaxLines) { + if (savedSizeBytes > plateMaxBytes) { toast.success("Document saved. It will reopen in raw markdown mode."); } else { toast.success("Document saved! Reindexing in background..."); @@ -470,7 +447,6 @@ export function EditorPanelContent({ memoryLimits, memoryScope, plateMaxBytes, - plateMaxLines, resolveLocalVirtualPath, searchSpaceId, ] @@ -491,12 +467,8 @@ export function EditorPanelContent({ const localFileLanguage = inferMonacoLanguageFromPath(localFilePath); const activeMarkdown = editedMarkdown ?? editorDoc?.source_markdown ?? 
""; const activeMarkdownSizeBytes = useMemo(() => getUtf8ByteSize(activeMarkdown), [activeMarkdown]); - const activeMarkdownLineCount = useMemo(() => countLines(activeMarkdown), [activeMarkdown]); - const isNearPlateLimit = - activeMarkdownSizeBytes >= plateMaxBytes * 0.9 || - activeMarkdownLineCount >= plateMaxLines * 0.9; - const isOverPlateLimit = - activeMarkdownSizeBytes > plateMaxBytes || activeMarkdownLineCount > plateMaxLines; + const isNearPlateLimit = activeMarkdownSizeBytes >= plateMaxBytes * 0.9; + const isOverPlateLimit = activeMarkdownSizeBytes > plateMaxBytes; const showPlateSizeWarning = showEditingActions && !isMemoryMode && !isLocalFileMode && isNearPlateLimit; const memoryLimitState = isMemoryMode @@ -509,13 +481,6 @@ export function EditorPanelContent({ ? "text-orange-500" : "text-muted-foreground"; const saveDisabled = saving || !hasUnsavedChanges || (memoryLimitState?.isOverLimit ?? false); - const editorInstanceKey = `${ - isMemoryMode - ? `memory-${memoryScope ?? "user"}` - : isLocalFileMode - ? (localFilePath ?? "local-file") - : documentId - }-${isEditing ? "editing" : "viewing"}`; const handleCancelEditing = useCallback(() => { const savedContent = editorDoc?.source_markdown ?? ""; @@ -531,9 +496,7 @@ export function EditorPanelContent({ setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown` - ), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); @@ -562,7 +525,7 @@ export function EditorPanelContent({ This document is too large for the editor ( - {formatBytes(editorDoc.content_size_bytes ?? 0)}, {docLineCount.toLocaleString()} lines,{" "} + {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} {editorDoc.chunk_count ?? 0} chunks). Showing raw markdown below.
) : ( diff --git a/surfsense_web/components/editor-panel/memory.ts b/surfsense_web/components/editor-panel/memory.ts index 1beb977a6..aa5b1f68d 100644 --- a/surfsense_web/components/editor-panel/memory.ts +++ b/surfsense_web/components/editor-panel/memory.ts @@ -1,7 +1,6 @@ "use client"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; export type MemoryScope = "user" | "team"; @@ -30,6 +29,10 @@ function getMemoryPath(scope: MemoryScope, searchSpaceId?: number | null) { return `/api/v1/searchspaces/${searchSpaceId}/memory`; } +function getBackendUrl(path: string) { + return `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}${path}`; +} + export function getMemoryLimitState(length: number, limits?: MemoryLimits | null) { if (!limits) { return { @@ -62,7 +65,7 @@ export async function fetchMemoryEditorDocument({ title?: string | null; signal?: AbortSignal; }) { - const response = await authenticatedFetch(buildBackendUrl(getMemoryPath(scope, searchSpaceId)), { + const response = await authenticatedFetch(getBackendUrl(getMemoryPath(scope, searchSpaceId)), { method: "GET", signal, }); @@ -94,7 +97,7 @@ export async function saveMemoryMarkdown({ searchSpaceId?: number | null; markdown: string; }) { - const response = await authenticatedFetch(buildBackendUrl(getMemoryPath(scope, searchSpaceId)), { + const response = await authenticatedFetch(getBackendUrl(getMemoryPath(scope, searchSpaceId)), { method: "PUT", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ memory_md: markdown }), diff --git a/surfsense_web/components/editor/plate-error-boundary.tsx b/surfsense_web/components/editor/plate-error-boundary.tsx deleted file mode 100644 index c5c18f5e0..000000000 --- a/surfsense_web/components/editor/plate-error-boundary.tsx +++ /dev/null @@ -1,34 +0,0 @@ -"use client"; - -import { Component, type ReactNode } from "react"; - -interface PlateErrorBoundaryProps { - children: ReactNode; - fallback: 
ReactNode; -} - -interface PlateErrorBoundaryState { - hasError: boolean; -} - -export class PlateErrorBoundary extends Component< - PlateErrorBoundaryProps, - PlateErrorBoundaryState -> { - constructor(props: PlateErrorBoundaryProps) { - super(props); - this.state = { hasError: false }; - } - - static getDerivedStateFromError(): PlateErrorBoundaryState { - return { hasError: true }; - } - - render() { - if (this.state.hasError) { - return this.props.fallback; - } - - return this.props.children; - } -} diff --git a/surfsense_web/components/free-chat/anonymous-chat.tsx b/surfsense_web/components/free-chat/anonymous-chat.tsx index e3b8273bc..aff58f7bc 100644 --- a/surfsense_web/components/free-chat/anonymous-chat.tsx +++ b/surfsense_web/components/free-chat/anonymous-chat.tsx @@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button"; import type { AnonModel, AnonQuotaResponse } from "@/contracts/types/anonymous-chat.types"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { readSSEStream } from "@/lib/chat/streaming-state"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; import { QuotaBar } from "./quota-bar"; @@ -81,7 +81,7 @@ export function AnonymousChat({ model }: AnonymousChatProps) { content: m.content, })); - const response = await fetch(buildBackendUrl("/api/v1/public/anon-chat/stream"), { + const response = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/stream`, { method: "POST", headers: { "Content-Type": "application/json" }, credentials: "include", @@ -188,6 +188,9 @@ export function AnonymousChat({ model }: AnonymousChatProps) {

{model.name}

+ {model.description && ( +

{model.description}

+ )}

Free to use · No login required · Start typing below

diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index 966aaee60..b28b1e0a1 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -33,8 +33,9 @@ import { updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; +import { FreeModelSelector } from "./free-model-selector"; import { FreeThread } from "./free-thread"; import { RemoveAdsBanner } from "./remove-ads-banner"; @@ -62,21 +63,6 @@ function normalizeFreeChatErrorMessage(error: unknown): string { if (code === "THREAD_BUSY") { return "A previous response is still stopping. Please try again in a moment."; } - if (code === "MODEL_AUTH_FAILED") { - return "This model’s API key is invalid or expired. Switch models, or update the API key."; - } - if (code === "MODEL_NOT_FOUND") { - return "This model is unavailable or no longer exists. Please switch models."; - } - if (code === "MODEL_CONTEXT_LIMIT") { - return "This request is too large for the selected model. Reduce the input or switch models."; - } - if (code === "MODEL_PROVIDER_UNAVAILABLE") { - return "The selected model provider is temporarily unavailable. Please try again or switch models."; - } - if (code === "RATE_LIMITED") { - return "This model is temporarily rate-limited. 
Please try again in a few seconds or switch models."; - } return error.message || "An unexpected error occurred"; } @@ -168,7 +154,7 @@ export function FreeChatPage() { assistantMsgId: string, signal: AbortSignal, turnstileToken: string | null - ): Promise<"captcha" | undefined> => { + ): Promise<"captcha" | void> => { const reqBody: Record = { model_slug: modelSlug, messages: messageHistory, @@ -176,7 +162,7 @@ export function FreeChatPage() { if (!webSearchEnabled) reqBody.disabled_tools = ["web_search"]; if (turnstileToken) reqBody.turnstile_token = turnstileToken; - const response = await fetch(buildBackendUrl("/api/v1/public/anon-chat/stream"), { + const response = await fetch(`${BACKEND_URL}/api/v1/public/anon-chat/stream`, { method: "POST", headers: { "Content-Type": "application/json" }, credentials: "include", @@ -498,6 +484,10 @@ export function FreeChatPage() {
+
+ +
+ {captchaRequired && TURNSTILE_SITE_KEY && ( diff --git a/surfsense_web/components/free-chat/free-composer.tsx b/surfsense_web/components/free-chat/free-composer.tsx index 162b906ad..46d9e0259 100644 --- a/surfsense_web/components/free-chat/free-composer.tsx +++ b/surfsense_web/components/free-chat/free-composer.tsx @@ -13,7 +13,6 @@ import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { useLoginGate } from "@/contexts/login-gate"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { cn } from "@/lib/utils"; -import { FreeModelSelector } from "./free-model-selector"; const ANON_ALLOWED_EXTENSIONS = new Set([ ".md", @@ -228,8 +227,7 @@ export const FreeComposer: FC = () => {
-
- +
{!isRunning ? ( diff --git a/surfsense_web/components/free-chat/free-model-selector.tsx b/surfsense_web/components/free-chat/free-model-selector.tsx index d04bca8a2..9bf4ecee5 100644 --- a/surfsense_web/components/free-chat/free-model-selector.tsx +++ b/surfsense_web/components/free-chat/free-model-selector.tsx @@ -1,6 +1,6 @@ "use client"; -import { Check, ChevronDown, Cpu } from "lucide-react"; +import { Bot, Check, ChevronDown } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useState } from "react"; import { Badge } from "@/components/ui/badge"; @@ -82,7 +82,7 @@ export function FreeModelSelector({ className }: { className?: string }) { ) : ( <> - + Select Model )} diff --git a/surfsense_web/components/homepage/global-announcement.tsx b/surfsense_web/components/homepage/global-announcement.tsx deleted file mode 100644 index 212be42c7..000000000 --- a/surfsense_web/components/homepage/global-announcement.tsx +++ /dev/null @@ -1,27 +0,0 @@ -import { IconInfoCircle } from "@tabler/icons-react"; -import { GLOBAL_ANNOUNCEMENT_ENABLED, GLOBAL_ANNOUNCEMENT_MESSAGE } from "@/lib/env-config"; - -/** - * Small, site-wide banner for planned downtime / maintenance notices. - * - * Controlled entirely through build-time env vars so it can be toggled from - * Vercel without a code change: - * - NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_ENABLED ("true" to show) - * - NEXT_PUBLIC_GLOBAL_ANNOUNCEMENT_MESSAGE (the copy to display) - */ -export function GlobalAnnouncement() { - const message = GLOBAL_ANNOUNCEMENT_MESSAGE.trim(); - - if (!GLOBAL_ANNOUNCEMENT_ENABLED || !message) { - return null; - } - - return ( -
-
- - {message} -
-
- ); -} diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx index 0f3bfe1aa..09cf316d8 100644 --- a/surfsense_web/components/homepage/hero-section.tsx +++ b/surfsense_web/components/homepage/hero-section.tsx @@ -37,7 +37,7 @@ import { getAssetLabel, usePrimaryDownload, } from "@/lib/desktop-download-utils"; -import { BUILD_TIME_AUTH_TYPE, buildBackendUrl } from "@/lib/env-config"; +import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config"; import { trackLoginAttempt } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; @@ -314,14 +314,14 @@ export function HeroSection() { } function GetStartedButton() { - const isGoogleAuth = BUILD_TIME_AUTH_TYPE === "GOOGLE"; + const isGoogleAuth = AUTH_TYPE === "GOOGLE"; const [isRedirecting, setIsRedirecting] = useState(false); const handleGoogleLogin = () => { if (isRedirecting) return; setIsRedirecting(true); trackLoginAttempt("google"); - window.location.href = buildBackendUrl("/auth/google/authorize-redirect"); + window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`; }; if (isGoogleAuth) { diff --git a/surfsense_web/components/icons/providers/azure.svg b/surfsense_web/components/icons/providers/azure.svg deleted file mode 100644 index ba80f55ca..000000000 --- a/surfsense_web/components/icons/providers/azure.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/bedrock.svg b/surfsense_web/components/icons/providers/bedrock.svg index cde500c0d..195aa6594 100644 --- a/surfsense_web/components/icons/providers/bedrock.svg +++ b/surfsense_web/components/icons/providers/bedrock.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/claude.svg b/surfsense_web/components/icons/providers/claude.svg deleted file mode 100644 index 8d732d5b0..000000000 --- a/surfsense_web/components/icons/providers/claude.svg 
+++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/index.ts b/surfsense_web/components/icons/providers/index.ts index 5c8276e62..aefa2a053 100644 --- a/surfsense_web/components/icons/providers/index.ts +++ b/surfsense_web/components/icons/providers/index.ts @@ -1,10 +1,8 @@ export { default as Ai21Icon } from "./ai21.svg"; export { default as AnthropicIcon } from "./anthropic.svg"; export { default as AnyscaleIcon } from "./anyscale.svg"; -export { default as AzureIcon } from "./azure.svg"; export { default as BedrockIcon } from "./bedrock.svg"; export { default as CerebrasIcon } from "./cerebras.svg"; -export { default as ClaudeIcon } from "./claude.svg"; export { default as CohereIcon } from "./cohere.svg"; export { default as CometApiIcon } from "./cometapi.svg"; export { default as DatabricksIcon } from "./dbrx.svg"; @@ -15,7 +13,6 @@ export { default as GeminiIcon } from "./gemini.svg"; export { default as GitHubModelsIcon } from "./github.svg"; export { default as GroqIcon } from "./groq.svg"; export { default as HuggingFaceIcon } from "./huggingface.svg"; -export { default as LmStudioIcon } from "./lm-studio.svg"; export { default as MiniMaxIcon } from "./minimax.svg"; export { default as MistralIcon } from "./mistral.svg"; export { default as MoonshotIcon } from "./moonshot.svg"; diff --git a/surfsense_web/components/icons/providers/lm-studio.svg b/surfsense_web/components/icons/providers/lm-studio.svg deleted file mode 100644 index b6ae7db3e..000000000 --- a/surfsense_web/components/icons/providers/lm-studio.svg +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - diff --git a/surfsense_web/components/icons/providers/vertexai.svg b/surfsense_web/components/icons/providers/vertexai.svg index e46a3ca0f..45adce83b 100644 --- a/surfsense_web/components/icons/providers/vertexai.svg +++ b/surfsense_web/components/icons/providers/vertexai.svg @@ -1 +1 @@ - \ No newline at end of file 
+ \ No newline at end of file diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 429a1fde8..549e6e7d7 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlarmClock, AlertTriangle, Inbox, LibraryBig } from "lucide-react"; +import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: AlarmClock, + icon: Workflow, isActive: isAutomationsActive, }, isMobile diff --git a/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx b/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx index 009b2c120..6f385b465 100644 --- a/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx +++ b/surfsense_web/components/layout/ui/dialogs/CreateSearchSpaceDialog.tsx @@ -67,7 +67,7 @@ export function CreateSearchSpaceDialog({ open, onOpenChange }: CreateSearchSpac trackSearchSpaceCreated(result.id, values.name); - router.push(`/dashboard/${result.id}/new-chat`); + router.push(`/dashboard/${result.id}/onboard`); } catch (error) { console.error("Failed to create search space:", error); setIsSubmitting(false); diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx index ea700391a..79839622d 100644 --- a/surfsense_web/components/layout/ui/header/Header.tsx +++ 
b/surfsense_web/components/layout/ui/header/Header.tsx @@ -6,6 +6,7 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeTabAtom } from "@/atoms/tabs/tabs.atom"; import { ActionLogButton } from "@/components/agent-action-log/action-log-button"; +import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatShareButton } from "@/components/new-chat/chat-share-button"; import type { ThreadRecord } from "@/lib/chat/thread-persistence"; @@ -65,8 +66,13 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { return (
- {/* Left side - Mobile menu trigger */} -
{mobileMenuTrigger}
+ {/* Left side - Mobile menu trigger + Model selector */} +
+ {mobileMenuTrigger} + {isChatPage && !isDocumentTab && searchSpaceId && ( + + )} +
{/* Right side - Actions */}
diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 44cc56ab0..6c6668319 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -43,7 +43,6 @@ import type { FolderDisplay } from "@/components/documents/FolderNode"; import { FolderPickerDialog } from "@/components/documents/FolderPickerDialog"; import { FolderTreeView } from "@/components/documents/FolderTreeView"; import { VersionHistoryDialog } from "@/components/documents/version-history"; -import { useRuntimeConfig } from "@/components/providers/runtime-config"; import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems"; import { DEFAULT_EXCLUDE_PATTERNS, @@ -79,7 +78,7 @@ import { foldersApiService } from "@/lib/apis/folders-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; @@ -227,7 +226,6 @@ function AuthenticatedDocumentsSidebarBase({ const isMobile = !useMediaQuery("(min-width: 640px)"); const platformElectronAPI = useElectronAPI(); const electronAPI = desktopFeaturesEnabled ? 
platformElectronAPI : null; - const { etlService } = useRuntimeConfig(); const searchSpaceId = Number(params.search_space_id); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const openEditorPanel = useSetAtom(openEditorPanelAtom); @@ -620,8 +618,7 @@ function AuthenticatedDocumentsSidebarBase({ folderName: matched.name, searchSpaceId, excludePatterns: matched.excludePatterns ?? DEFAULT_EXCLUDE_PATTERNS, - fileExtensions: - matched.fileExtensions ?? Array.from(getSupportedExtensionsSet(undefined, etlService)), + fileExtensions: matched.fileExtensions ?? Array.from(getSupportedExtensionsSet()), rootFolderId: folder.id, }); toast.success(`Re-scan complete: ${matched.name}`); @@ -629,7 +626,7 @@ function AuthenticatedDocumentsSidebarBase({ toast.error((err as Error)?.message || "Failed to re-scan folder"); } }, - [searchSpaceId, electronAPI, etlService] + [searchSpaceId, electronAPI] ); const handleStopWatching = useCallback( @@ -751,9 +748,7 @@ function AuthenticatedDocumentsSidebarBase({ .trim() .slice(0, 80) || "folder"; await doExport( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`, { - folder_id: ctx.folder.id, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`, `${safeName}.zip` ); toast.success(`Folder "${ctx.folder.name}" exported`); @@ -805,9 +800,7 @@ function AuthenticatedDocumentsSidebarBase({ .trim() .slice(0, 80) || "folder"; await doExport( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`, { - folder_id: folder.id, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${folder.id}`, `${safeName}.zip` ); toast.success(`Folder "${folder.name}" exported`); @@ -827,8 +820,8 @@ function AuthenticatedDocumentsSidebarBase({ try { const endpoint = doc.document_type === "USER_MEMORY" - ? buildBackendUrl("/api/v1/users/me/memory") - : buildBackendUrl(`/api/v1/searchspaces/${searchSpaceId}/memory`); + ? 
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/memory` + : `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/searchspaces/${searchSpaceId}/memory`; const response = await authenticatedFetch(endpoint, { method: "GET" }); if (!response.ok) { const errorData = await response.json().catch(() => ({ detail: "Export failed" })); @@ -856,9 +849,7 @@ function AuthenticatedDocumentsSidebarBase({ try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export`, { - format, - }), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${format}`, { method: "GET" } ); @@ -1037,8 +1028,8 @@ function AuthenticatedDocumentsSidebarBase({ } const endpoint = doc.document_type === "USER_MEMORY" - ? buildBackendUrl("/api/v1/users/me/memory/reset") - : buildBackendUrl(`/api/v1/searchspaces/${searchSpaceId}/memory/reset`); + ? `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/memory/reset` + : `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/searchspaces/${searchSpaceId}/memory/reset`; try { const response = await authenticatedFetch(endpoint, { method: "POST" }); if (!response.ok) { @@ -1158,7 +1149,6 @@ function AuthenticatedDocumentsSidebarBase({ const showCloudSkeleton = currentFilesystemTab === "cloud" && (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); - const connectorButtonLabel = connectorCount > 0 ? "Manage connectors" : "Connect your connectors"; const cloudContent = ( <> @@ -1171,7 +1161,9 @@ function AuthenticatedDocumentsSidebarBase({ className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground" > - {connectorButtonLabel} + + {connectorCount > 0 ? 
"Manage connectors" : "Connect your connectors"} + {connectorCount > 0 && ( {connectorCount} diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index d50d28a3c..61b8c3e25 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -11,7 +11,7 @@ import { Alert, AlertDescription } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { authenticatedFetch, getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; const LARGE_DOCUMENT_THRESHOLD = 2 * 1024 * 1024; // 2MB @@ -108,12 +108,10 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen } try { - const response = await authenticatedFetch( - buildBackendUrl( - `/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` - ), - { method: "GET" } + const url = new URL( + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); + const response = await authenticatedFetch(url.toString(), { method: "GET" }); if (controller.signal.aborted) return; @@ -167,7 +165,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen setSaving(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/save`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -325,9 +323,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen setDownloading(true); try { const response = await authenticatedFetch( - buildBackendUrl( - 
`/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown` - ), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, { method: "GET" } ); if (!response.ok) throw new Error("Download failed"); diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 344176629..61041cc29 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,12 +1,12 @@ "use client"; import { - AlarmClock, FilePlus2, type LucideIcon, Search, Settings2, WandSparkles, + Workflow, X, } from "lucide-react"; import { memo, useCallback, useState } from "react"; @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record = { search: Search, create: FilePlus2, - automate: AlarmClock, + automate: Workflow, tools: Settings2, }; diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 9882530d4..4716418ee 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -1,23 +1,167 @@ "use client"; -import { ImageModelSelector } from "./image-model-selector"; +import { useCallback, useState } from "react"; +import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; +import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; +import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; +import type { + GlobalImageGenConfig, + GlobalNewLLMConfig, + GlobalVisionLLMConfig, + ImageGenerationConfig, + NewLLMConfigPublic, + VisionLLMConfig, +} from "@/contracts/types/new-llm-config.types"; import { ModelSelector } from "./model-selector"; interface ChatHeaderProps { searchSpaceId: number; className?: string; - onChatModelSelected?: () => void; } -export function ChatHeader({ 
searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) { +export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { + // LLM config dialog state + const [dialogOpen, setDialogOpen] = useState(false); + const [selectedConfig, setSelectedConfig] = useState< + NewLLMConfigPublic | GlobalNewLLMConfig | null + >(null); + const [isGlobal, setIsGlobal] = useState(false); + const [dialogMode, setDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Image config dialog state + const [imageDialogOpen, setImageDialogOpen] = useState(false); + const [selectedImageConfig, setSelectedImageConfig] = useState< + ImageGenerationConfig | GlobalImageGenConfig | null + >(null); + const [isImageGlobal, setIsImageGlobal] = useState(false); + const [imageDialogMode, setImageDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Vision config dialog state + const [visionDialogOpen, setVisionDialogOpen] = useState(false); + const [selectedVisionConfig, setSelectedVisionConfig] = useState< + VisionLLMConfig | GlobalVisionLLMConfig | null + >(null); + const [isVisionGlobal, setIsVisionGlobal] = useState(false); + const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); + + // Default provider for create dialogs + const [defaultLLMProvider, setDefaultLLMProvider] = useState(); + const [defaultImageProvider, setDefaultImageProvider] = useState(); + const [defaultVisionProvider, setDefaultVisionProvider] = useState(); + + // LLM handlers + const handleEditLLMConfig = useCallback( + (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { + setSelectedConfig(config); + setIsGlobal(global); + setDialogMode(global ? 
"view" : "edit"); + setDefaultLLMProvider(undefined); + setDialogOpen(true); + }, + [] + ); + + const handleAddNewLLM = useCallback((provider?: string) => { + setSelectedConfig(null); + setIsGlobal(false); + setDialogMode("create"); + setDefaultLLMProvider(provider); + setDialogOpen(true); + }, []); + + const handleDialogClose = useCallback((open: boolean) => { + setDialogOpen(open); + if (!open) setSelectedConfig(null); + }, []); + + // Image model handlers + const handleAddImageModel = useCallback((provider?: string) => { + setSelectedImageConfig(null); + setIsImageGlobal(false); + setImageDialogMode("create"); + setDefaultImageProvider(provider); + setImageDialogOpen(true); + }, []); + + const handleEditImageConfig = useCallback( + (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { + setSelectedImageConfig(config); + setIsImageGlobal(global); + setImageDialogMode(global ? "view" : "edit"); + setDefaultImageProvider(undefined); + setImageDialogOpen(true); + }, + [] + ); + + const handleImageDialogClose = useCallback((open: boolean) => { + setImageDialogOpen(open); + if (!open) setSelectedImageConfig(null); + }, []); + + // Vision model handlers + const handleAddVisionModel = useCallback((provider?: string) => { + setSelectedVisionConfig(null); + setIsVisionGlobal(false); + setVisionDialogMode("create"); + setDefaultVisionProvider(provider); + setVisionDialogOpen(true); + }, []); + + const handleEditVisionConfig = useCallback( + (config: VisionLLMConfig | GlobalVisionLLMConfig, global: boolean) => { + setSelectedVisionConfig(config); + setIsVisionGlobal(global); + setVisionDialogMode(global ? "view" : "edit"); + setDefaultVisionProvider(undefined); + setVisionDialogOpen(true); + }, + [] + ); + + const handleVisionDialogClose = useCallback((open: boolean) => { + setVisionDialogOpen(open); + if (!open) setSelectedVisionConfig(null); + }, []); + return (
- + + +
); } diff --git a/surfsense_web/components/new-chat/image-model-selector.tsx b/surfsense_web/components/new-chat/image-model-selector.tsx deleted file mode 100644 index e90a46c09..000000000 --- a/surfsense_web/components/new-chat/image-model-selector.tsx +++ /dev/null @@ -1,301 +0,0 @@ -"use client"; - -import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, ImagePlus, Search, SlidersHorizontal } from "lucide-react"; -import { useRouter } from "next/navigation"; -import type { UIEvent } from "react"; -import { useCallback, useMemo, useState } from "react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { - Drawer, - DrawerContent, - DrawerHandle, - DrawerHeader, - DrawerTitle, - DrawerTrigger, -} from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { Spinner } from "@/components/ui/spinner"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import { useIsMobile } from "@/hooks/use-mobile"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; -import { providerDisplay } from "../settings/model-connections/provider-metadata"; - -interface ImageModelSelectorProps { - searchSpaceId: number; - className?: string; -} - -type ImageModel = ModelRead & { - connectionId: number; - connectionLabel: string; - connectionScope: string; - provider: string; -}; - -const AUTO_IMAGE_MODEL_ID = 0; - -function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Global"; - return 
providerDisplay(connection.provider).name; -} - -function flattenImageModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models - .filter((model) => model.enabled && Boolean(model.supports_image_generation)) - .map((model) => ({ - ...model, - connectionId: connection.id, - connectionLabel: connectionLabel(connection), - connectionScope: connection.scope, - provider: connection.provider, - })) - ); -} - -function isFreeGlobalModel(model: ImageModel) { - return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; -} - -function modelName(model: ImageModel) { - const name = model.display_name || model.model_id; - if (model.connectionScope === "GLOBAL") { - return name.replace(/\s+\(free\)$/i, ""); - } - return name; -} - -function filterImageModels(models: ImageModel[], search: string) { - const normalized = search.trim().toLowerCase(); - if (!normalized) return models; - return models.filter((model) => - [modelName(model), model.model_id, model.connectionLabel] - .join(" ") - .toLowerCase() - .includes(normalized) - ); -} - -function groupedModels(models: ImageModel[]) { - return models.reduce>((groups, model) => { - const key = model.connectionLabel; - if (!groups[key]) groups[key] = []; - groups[key].push(model); - return groups; - }, {}); -} - -export function ImageModelSelector({ searchSpaceId, className }: ImageModelSelectorProps) { - const router = useRouter(); - const isMobile = useIsMobile(); - const [open, setOpen] = useState(false); - const [search, setSearch] = useState(""); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( - globalModelConnectionsAtom - ); - const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = 
useAtomValue(updateModelRolesMutationAtom); - - const allImageModels = useMemo( - () => flattenImageModels([...globalConnections, ...connections]), - [globalConnections, connections] - ); - - const visibleImageModels = useMemo( - () => filterImageModels(allImageModels, search), - [allImageModels, search] - ); - const imageModelsById = useMemo( - () => new Map(allImageModels.map((model) => [model.id, model])), - [allImageModels] - ); - const selectedModelId = roles?.image_gen_model_id ?? AUTO_IMAGE_MODEL_ID; - const selected = imageModelsById.get(selectedModelId); - const groups = useMemo(() => groupedModels(visibleImageModels), [visibleImageModels]); - const loading = globalLoading || connectionsLoading; - const hasSearchQuery = search.trim().length > 0; - - function handleOpenChange(nextOpen: boolean) { - if (!nextOpen) setSearch(""); - setOpen(nextOpen); - } - - function selectModel(modelId: number) { - updateRoles.mutate({ image_gen_model_id: modelId }); - setSearch(""); - setOpen(false); - } - - function manageModelConnections() { - setOpen(false); - router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); - } - - const handleScroll = useCallback((event: UIEvent) => { - const el = event.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - // Only surface this control when usable image-generation models exist. - if (!loading && allImageModels.length === 0) { - return null; - } - - const content = ( -
-
-
- - setSearch(event.target.value)} - placeholder="Search image models" - className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" - /> -
-
-
- - {loading ? ( -
- -
- ) : Object.keys(groups).length === 0 ? ( -
- {hasSearchQuery - ? "No matching image models." - : "No enabled image models. Add or enable models in Settings."} -
- ) : ( - Object.entries(groups).map(([connection, models]) => ( -
-
- {connection} -
- {models.map((model) => ( - - ))} -
- )) - )} -
-
- -
-
- ); - - const trigger = ( - - ); - - if (isMobile) { - return ( - - {trigger} - - - - Select Image Model - - {content} - - - ); - } - - return ( - - {trigger} - - {content} - - - ); -} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 22d86aa92..0a096f5f8 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,16 +1,40 @@ "use client"; -import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, Search, SlidersHorizontal } from "lucide-react"; -import { useRouter } from "next/navigation"; -import type { UIEvent } from "react"; -import { useCallback, useMemo, useState } from "react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { useAtomValue } from "jotai"; import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; + Bot, + Check, + ChevronDown, + ChevronLeft, + ChevronRight, + ChevronUp, + ImageIcon, + Layers, + Pencil, + Plus, + ScanEye, + Search, + Zap, +} from "lucide-react"; +import type React from "react"; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from 
"@/atoms/vision-llm-config/vision-llm-config-query.atoms"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { @@ -21,284 +45,1389 @@ import { DrawerTitle, DrawerTrigger, } from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import type { + GlobalImageGenConfig, + GlobalNewLLMConfig, + GlobalVisionLLMConfig, + ImageGenerationConfig, + NewLLMConfigPublic, + VisionLLMConfig, +} from "@/contracts/types/new-llm-config.types"; import { useIsMobile } from "@/hooks/use-mobile"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; +import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; -import { providerDisplay } from "../settings/model-connections/provider-metadata"; -interface ModelSelectorProps { - searchSpaceId: number; - className?: string; - onChatModelSelected?: () => void; -} +// ─── Helpers ──────────────────────────────────────────────────────── -type ChatModel = ModelRead & { - connectionId: number; - connectionLabel: string; - connectionScope: string; - provider: string; +const PROVIDER_NAMES: Record = { + OPENAI: "OpenAI", + ANTHROPIC: "Anthropic", + GOOGLE: "Google", + AZURE: "Azure", + AZURE_OPENAI: "Azure OpenAI", + AWS_BEDROCK: "AWS Bedrock", + BEDROCK: "Bedrock", + DEEPSEEK: "DeepSeek", + MISTRAL: "Mistral", + COHERE: "Cohere", + GITHUB_MODELS: "GitHub Models", + GROQ: "Groq", + OLLAMA: "Ollama", + TOGETHER_AI: "Together AI", + FIREWORKS_AI: "Fireworks AI", + REPLICATE: "Replicate", + HUGGINGFACE: "HuggingFace", + PERPLEXITY: "Perplexity", + XAI: "xAI", + OPENROUTER: "OpenRouter", + CEREBRAS: "Cerebras", + 
SAMBANOVA: "SambaNova", + VERTEX_AI: "Vertex AI", + MINIMAX: "MiniMax", + MOONSHOT: "Moonshot", + ZHIPU: "Zhipu", + DEEPINFRA: "DeepInfra", + CLOUDFLARE: "Cloudflare", + DATABRICKS: "Databricks", + NSCALE: "NScale", + RECRAFT: "Recraft", + XINFERENCE: "XInference", + CUSTOM: "Custom", + AI21: "AI21", + ALIBABA_QWEN: "Qwen", + ANYSCALE: "Anyscale", + COMETAPI: "CometAPI", }; -const AUTO_CHAT_MODEL_ID = 0; +// Provider keys valid per model type, matching backend enums +// (LiteLLMProvider, ImageGenProvider, VisionProvider in db.py) +const LLM_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "DEEPSEEK", + "XAI", + "MISTRAL", + "COHERE", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "CEREBRAS", + "SAMBANOVA", + "DEEPINFRA", + "AI21", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "MINIMAX", + "HUGGINGFACE", + "CLOUDFLARE", + "DATABRICKS", + "ANYSCALE", + "COMETAPI", + "GITHUB_MODELS", + "CUSTOM", +]; -function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Global"; - return providerDisplay(connection.provider).name; -} +const IMAGE_PROVIDER_KEYS: string[] = [ + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", +]; -function flattenChatModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models - .filter((model) => model.enabled && Boolean(model.supports_chat)) - .map((model) => ({ - ...model, - connectionId: connection.id, - connectionLabel: connectionLabel(connection), - connectionScope: connection.scope, - provider: connection.provider, - })) +const VISION_PROVIDER_KEYS: string[] = [ + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + 
"CUSTOM", +]; + +const PROVIDER_KEYS_BY_TAB: Record = { + llm: LLM_PROVIDER_KEYS, + image: IMAGE_PROVIDER_KEYS, + vision: VISION_PROVIDER_KEYS, +}; + +function formatProviderName(provider: string): string { + const key = provider.toUpperCase(); + return ( + PROVIDER_NAMES[key] ?? + provider.charAt(0).toUpperCase() + provider.slice(1).toLowerCase().replace(/_/g, " ") ); } -function isFreeGlobalModel(model: ChatModel) { - return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; +function normalizeText(input: string): string { + return input + .normalize("NFD") + .replace(/\p{Diacritic}/gu, "") + .toLowerCase() + .replace(/[^a-z0-9]+/g, " ") + .trim(); } -function modelName(model: ChatModel) { - const name = model.display_name || model.model_id; - if (model.connectionScope === "GLOBAL") { - return name.replace(/\s+\(free\)$/i, ""); +interface ConfigBase { + id: number; + name: string; + model_name: string; + provider: string; +} + +function filterAndScore( + configs: T[], + selectedProvider: string, + searchQuery: string +): T[] { + let result = configs; + + if (selectedProvider !== "all") { + result = result.filter((c) => c.provider.toUpperCase() === selectedProvider); } - return name; + + if (!searchQuery.trim()) return result; + + const normalized = normalizeText(searchQuery); + const tokens = normalized.split(/\s+/).filter(Boolean); + + const scored = result.map((c) => { + const aggregate = normalizeText([c.name, c.model_name, c.provider].join(" ")); + let score = 0; + if (aggregate.includes(normalized)) score += 5; + for (const token of tokens) { + if (aggregate.includes(token)) score += 1; + } + return { config: c, score }; + }); + + return scored + .filter((s) => s.score > 0) + .sort((a, b) => b.score - a.score) + .map((s) => s.config); } -function filterChatModels(models: ChatModel[], search: string) { - const normalized = search.trim().toLowerCase(); - if (!normalized) return models; - return models.filter((model) => - 
[modelName(model), model.model_id, model.connectionLabel] - .join(" ") - .toLowerCase() - .includes(normalized) +interface DisplayItem { + config: ConfigBase & Record; + isGlobal: boolean; + isAutoMode: boolean; +} + +const TruncatedNameWithTooltip: React.FC<{ + text: string; + className?: string; + enableTooltip: boolean; +}> = ({ text, className, enableTooltip }) => { + const textRef = useRef(null); + const openTimerRef = useRef(undefined); + const [isTruncated, setIsTruncated] = useState(false); + const [open, setOpen] = useState(false); + + const recalcTruncation = useCallback(() => { + const el = textRef.current; + if (!el) return; + setIsTruncated(el.scrollWidth > el.clientWidth + 1); + }, []); + + useEffect(() => { + if (!enableTooltip) return; + const el = textRef.current; + if (!el) return; + + const raf = requestAnimationFrame(recalcTruncation); + recalcTruncation(); + + const observer = new ResizeObserver(recalcTruncation); + observer.observe(el); + if (el.parentElement) observer.observe(el.parentElement); + window.addEventListener("resize", recalcTruncation); + + return () => { + cancelAnimationFrame(raf); + observer.disconnect(); + window.removeEventListener("resize", recalcTruncation); + }; + }, [enableTooltip, recalcTruncation]); + + useEffect(() => { + // Recompute when row text changes. 
+ void text; + requestAnimationFrame(recalcTruncation); + }, [text, recalcTruncation]); + + useEffect( + () => () => { + if (openTimerRef.current) window.clearTimeout(openTimerRef.current); + }, + [] ); -} -function groupedModels(models: ChatModel[]) { - return models.reduce>((groups, model) => { - const key = model.connectionLabel; - if (!groups[key]) groups[key] = []; - groups[key].push(model); - return groups; - }, {}); + if (!enableTooltip) { + return ( + + {text} + + ); + } + + const handleOpenChange = (nextOpen: boolean) => { + if (openTimerRef.current) { + window.clearTimeout(openTimerRef.current); + openTimerRef.current = undefined; + } + if (!nextOpen) { + setOpen(false); + return; + } + if (!isTruncated) return; + openTimerRef.current = window.setTimeout(() => { + setOpen(true); + openTimerRef.current = undefined; + }, 220); + }; + + return ( + + + + {text} + + + + {text} + + + ); +}; + +// ─── Component ────────────────────────────────────────────────────── + +interface ModelSelectorProps { + onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; + onAddNewLLM: (provider?: string) => void; + onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; + onAddNewImage?: (provider?: string) => void; + onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; + onAddNewVision?: (provider?: string) => void; + className?: string; } export function ModelSelector({ - searchSpaceId, + onEditLLM, + onAddNewLLM, + onEditImage, + onAddNewImage, + onEditVision, + onAddNewVision, className, - onChatModelSelected, }: ModelSelectorProps) { - const router = useRouter(); - const isMobile = useIsMobile(); const [open, setOpen] = useState(false); - const [search, setSearch] = useState(""); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( - 
globalModelConnectionsAtom - ); - const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = useAtomValue(updateModelRolesMutationAtom); + const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">("llm"); + const [searchQuery, setSearchQuery] = useState(""); + const [selectedProvider, setSelectedProvider] = useState("all"); + const [focusedIndex, setFocusedIndex] = useState(-1); + const [modelScrollPos, setModelScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const [sidebarScrollPos, setSidebarScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const providerSidebarRef = useRef(null); + const modelListRef = useRef(null); + const searchInputRef = useRef(null); + const isMobile = useIsMobile(); - const allChatModels = useMemo( - () => flattenChatModels([...globalConnections, ...connections]), - [globalConnections, connections] + const handleOpenChange = useCallback( + (next: boolean) => { + if (next) { + setSearchQuery(""); + setSelectedProvider("all"); + if (!isMobile) { + requestAnimationFrame(() => searchInputRef.current?.focus()); + } + } + setOpen(next); + }, + [isMobile] ); - const visibleChatModels = useMemo( - () => filterChatModels(allChatModels, search), - [allChatModels, search] + const handleTabChange = useCallback( + (next: "llm" | "image" | "vision") => { + setActiveTab(next); + setSelectedProvider("all"); + setSearchQuery(""); + setFocusedIndex(-1); + setModelScrollPos("top"); + if (open && !isMobile) { + requestAnimationFrame(() => searchInputRef.current?.focus()); + } + }, + [open, isMobile] ); - const chatModelsById = useMemo( - () => new Map(allChatModels.map((model) => [model.id, model])), - [allChatModels] - ); - const selectedModelId = roles?.chat_model_id ?? 
AUTO_CHAT_MODEL_ID; - const selected = chatModelsById.get(selectedModelId); - const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]); - const loading = globalLoading || connectionsLoading; - const hasSearchQuery = search.trim().length > 0; - function handleOpenChange(nextOpen: boolean) { - if (!nextOpen) setSearch(""); - setOpen(nextOpen); - } - - function selectModel(modelId: number) { - updateRoles.mutate({ chat_model_id: modelId }); - setSearch(""); - setOpen(false); - requestAnimationFrame(() => { - onChatModelSelected?.(); - }); - } - - function manageModelConnections() { - setOpen(false); - router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); - } - - const handleScroll = useCallback((event: UIEvent) => { - const el = event.currentTarget; + const handleModelListScroll = useCallback((e: React.UIEvent) => { + const el = e.currentTarget; const atTop = el.scrollTop <= 2; const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + setModelScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); }, []); - const content = ( -
-
-
- - setSearch(event.target.value)} - placeholder="Search chat models" - className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" - /> -
-
-
- - {loading ? ( -
- -
- ) : Object.keys(groups).length === 0 ? ( -
- {hasSearchQuery - ? "No matching chat models." - : "No enabled chat models. Add or enable models in Settings."} -
- ) : ( - Object.entries(groups).map(([connection, models]) => ( -
-
- {connection} -
- {models.map((model) => ( - - ))} -
- )) - )} -
-
- -
-
+ const handleSidebarScroll = useCallback( + (e: React.UIEvent) => { + const el = e.currentTarget; + if (isMobile) { + const atStart = el.scrollLeft <= 2; + const atEnd = el.scrollWidth - el.scrollLeft - el.clientWidth <= 2; + setSidebarScrollPos(atStart ? "top" : atEnd ? "bottom" : "middle"); + } else { + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setSidebarScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + } + }, + [isMobile] ); - const trigger = ( + const scrollProviderSidebar = useCallback( + (direction: "backward" | "forward") => { + const el = providerSidebarRef.current; + if (!el) return; + const delta = isMobile + ? Math.max(56, Math.floor(el.clientWidth * 0.5)) + : Math.max(44, Math.floor(el.clientHeight * 0.4)); + + if (isMobile) { + el.scrollBy({ + left: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + return; + } + + el.scrollBy({ + top: direction === "backward" ? -delta : delta, + behavior: "smooth", + }); + }, + [isMobile] + ); + + // Cmd/Ctrl+M shortcut (desktop only) + useEffect(() => { + if (isMobile) return; + const handler = (e: KeyboardEvent) => { + if ((e.metaKey || e.ctrlKey) && e.key === "m") { + e.preventDefault(); + // setOpen((prev) => !prev); + handleOpenChange(!open); + } + }; + document.addEventListener("keydown", handler); + return () => document.removeEventListener("keydown", handler); + }, [isMobile, open, handleOpenChange]); + + // ─── Data ─── + const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); + const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = + useAtomValue(globalNewLLMConfigsAtom); + const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); + const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + const { data: imageGlobalConfigs, isLoading: 
imageGlobalLoading } = + useAtomValue(globalImageGenConfigsAtom); + const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); + const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( + globalVisionLLMConfigsAtom + ); + const { data: visionUserConfigs, isLoading: visionUserLoading } = + useAtomValue(visionLLMConfigsAtom); + + // Pending image attachments on the composer. Used to surface an + // amber "No image" hint on chat models the catalog reports as + // non-vision (`supports_image_input=false`) when the next message + // will carry an image. The hint is purely advisory: selection, + // focus, and click handling are unaffected. The backend's safety + // net (`is_known_text_only_chat_model`) is the actual block, and + // it only fires when LiteLLM *explicitly* marks a model as + // text-only — so a model that's secretly capable but hasn't been + // annotated will still flow through to the provider. + const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); + const hasPendingImages = pendingUserImageUrls.length > 0; + + const isLoading = + llmUserLoading || + llmGlobalLoading || + prefsLoading || + imageGlobalLoading || + imageUserLoading || + visionGlobalLoading || + visionUserLoading; + + // ─── Current selected configs ─── + const currentLLMConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.agent_llm_id; + if (id === null || id === undefined) return null; + if (id <= 0) return llmGlobalConfigs?.find((c) => c.id === id) ?? null; + return llmUserConfigs?.find((c) => c.id === id) ?? 
null; + }, [preferences, llmGlobalConfigs, llmUserConfigs]); + + const isLLMAutoMode = + currentLLMConfig && "is_auto_mode" in currentLLMConfig && currentLLMConfig.is_auto_mode; + + const currentImageConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.image_generation_config_id; + if (id === null || id === undefined) return null; + return ( + imageGlobalConfigs?.find((c) => c.id === id) ?? + imageUserConfigs?.find((c) => c.id === id) ?? + null + ); + }, [preferences, imageGlobalConfigs, imageUserConfigs]); + + const isImageAutoMode = + currentImageConfig && "is_auto_mode" in currentImageConfig && currentImageConfig.is_auto_mode; + + const currentVisionConfig = useMemo(() => { + if (!preferences) return null; + const id = preferences.vision_llm_config_id; + if (id === null || id === undefined) return null; + return ( + visionGlobalConfigs?.find((c) => c.id === id) ?? + visionUserConfigs?.find((c) => c.id === id) ?? + null + ); + }, [preferences, visionGlobalConfigs, visionUserConfigs]); + + const isVisionAutoMode = + currentVisionConfig && + "is_auto_mode" in currentVisionConfig && + currentVisionConfig.is_auto_mode; + + // ─── Filtered configs (separate global / user for section headers) ─── + const filteredLLMGlobal = useMemo( + () => filterAndScore(llmGlobalConfigs ?? [], selectedProvider, searchQuery), + [llmGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredLLMUser = useMemo( + () => filterAndScore(llmUserConfigs ?? [], selectedProvider, searchQuery), + [llmUserConfigs, selectedProvider, searchQuery] + ); + const filteredImageGlobal = useMemo( + () => filterAndScore(imageGlobalConfigs ?? [], selectedProvider, searchQuery), + [imageGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredImageUser = useMemo( + () => filterAndScore(imageUserConfigs ?? 
[], selectedProvider, searchQuery), + [imageUserConfigs, selectedProvider, searchQuery] + ); + const filteredVisionGlobal = useMemo( + () => filterAndScore(visionGlobalConfigs ?? [], selectedProvider, searchQuery), + [visionGlobalConfigs, selectedProvider, searchQuery] + ); + const filteredVisionUser = useMemo( + () => filterAndScore(visionUserConfigs ?? [], selectedProvider, searchQuery), + [visionUserConfigs, selectedProvider, searchQuery] + ); + + // Combined display list for keyboard navigation + const currentDisplayItems: DisplayItem[] = useMemo(() => { + const toItems = (configs: ConfigBase[], isGlobal: boolean): DisplayItem[] => + configs.map((c) => ({ + config: c as ConfigBase & Record, + isGlobal, + isAutoMode: + isGlobal && "is_auto_mode" in c && !!(c as Record).is_auto_mode, + })); + + const sortGlobalItems = (items: DisplayItem[]): DisplayItem[] => + [...items].sort((a, b) => { + if (a.isAutoMode !== b.isAutoMode) return a.isAutoMode ? -1 : 1; + const aPremium = !!(a.config as Record).is_premium; + const bPremium = !!(b.config as Record).is_premium; + if (aPremium !== bPremium) return aPremium ? 1 : -1; + return 0; + }); + + switch (activeTab) { + case "llm": + return [ + ...sortGlobalItems(toItems(filteredLLMGlobal, true)), + ...toItems(filteredLLMUser, false), + ]; + case "image": + return [ + ...sortGlobalItems(toItems(filteredImageGlobal, true)), + ...toItems(filteredImageUser, false), + ]; + case "vision": + return [ + ...sortGlobalItems(toItems(filteredVisionGlobal, true)), + ...toItems(filteredVisionUser, false), + ]; + } + }, [ + activeTab, + filteredLLMGlobal, + filteredLLMUser, + filteredImageGlobal, + filteredImageUser, + filteredVisionGlobal, + filteredVisionUser, + ]); + + // ─── Provider sidebar data ─── + // Collect which providers actually have configured models for the active tab + const configuredProviderSet = useMemo(() => { + const configs = + activeTab === "llm" + ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? 
[])] + : activeTab === "image" + ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] + : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; + const set = new Set(); + for (const c of configs) { + if (c.provider) set.add(c.provider.toUpperCase()); + } + return set; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); + + // Show only providers valid for the active tab; configured ones first + const activeProviders = useMemo(() => { + const tabKeys = PROVIDER_KEYS_BY_TAB[activeTab] ?? LLM_PROVIDER_KEYS; + const configured = tabKeys.filter((p) => configuredProviderSet.has(p)); + const unconfigured = tabKeys.filter((p) => !configuredProviderSet.has(p)); + return ["all", ...configured, ...unconfigured]; + }, [activeTab, configuredProviderSet]); + + const providerModelCounts = useMemo(() => { + const allConfigs = + activeTab === "llm" + ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] + : activeTab === "image" + ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] + : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? 
[])]; + const counts: Record = { all: allConfigs.length }; + for (const c of allConfigs) { + const p = c.provider.toUpperCase(); + counts[p] = (counts[p] || 0) + 1; + } + return counts; + }, [ + activeTab, + llmGlobalConfigs, + llmUserConfigs, + imageGlobalConfigs, + imageUserConfigs, + visionGlobalConfigs, + visionUserConfigs, + ]); + + // ─── Selection handlers ─── + const handleSelectLLM = useCallback( + async (config: NewLLMConfigPublic | GlobalNewLLMConfig) => { + if (currentLLMConfig?.id === config.id) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { agent_llm_id: config.id }, + }); + toast.success(`Switched to ${config.name}`); + setOpen(false); + } catch { + toast.error("Failed to switch model"); + } + }, + [currentLLMConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectImage = useCallback( + async (configId: number) => { + if (currentImageConfig?.id === configId) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { image_generation_config_id: configId }, + }); + toast.success("Image model updated"); + setOpen(false); + } catch { + toast.error("Failed to switch image model"); + } + }, + [currentImageConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectVision = useCallback( + async (configId: number) => { + if (currentVisionConfig?.id === configId) { + setOpen(false); + return; + } + if (!searchSpaceId) { + toast.error("No search space selected"); + return; + } + try { + await updatePreferences({ + search_space_id: Number(searchSpaceId), + data: { vision_llm_config_id: configId }, + }); + toast.success("Vision model updated"); + setOpen(false); + } catch { + toast.error("Failed to switch vision model"); + } + }, + 
[currentVisionConfig, searchSpaceId, updatePreferences] + ); + + const handleSelectItem = useCallback( + (item: DisplayItem) => { + switch (activeTab) { + case "llm": + handleSelectLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig); + break; + case "image": + handleSelectImage(item.config.id); + break; + case "vision": + handleSelectVision(item.config.id); + break; + } + }, + [activeTab, handleSelectLLM, handleSelectImage, handleSelectVision] + ); + + const handleEditItem = useCallback( + (e: React.MouseEvent, item: DisplayItem) => { + e.stopPropagation(); + setOpen(false); + switch (activeTab) { + case "llm": + onEditLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig, item.isGlobal); + break; + case "image": + onEditImage?.(item.config as ImageGenerationConfig | GlobalImageGenConfig, item.isGlobal); + break; + case "vision": + onEditVision?.(item.config as VisionLLMConfig | GlobalVisionLLMConfig, item.isGlobal); + break; + } + }, + [activeTab, onEditLLM, onEditImage, onEditVision] + ); + + // ─── Keyboard navigation ─── + // biome-ignore lint/correctness/useExhaustiveDependencies: searchQuery and selectedProvider are intentional triggers to reset focus + useEffect(() => { + setFocusedIndex(-1); + }, [searchQuery, selectedProvider]); + + useEffect(() => { + if (focusedIndex < 0 || !modelListRef.current) return; + const items = modelListRef.current.querySelectorAll("[data-model-index]"); + items[focusedIndex]?.scrollIntoView({ + block: "nearest", + behavior: "smooth", + }); + }, [focusedIndex]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + const count = currentDisplayItems.length; + + // Arrow Left/Right cycle provider filters + if (e.key === "ArrowLeft" || e.key === "ArrowRight") { + e.preventDefault(); + const providers = activeProviders; + const idx = providers.indexOf(selectedProvider); + let next: number; + if (e.key === "ArrowLeft") { + next = idx > 0 ? 
idx - 1 : providers.length - 1; + } else { + next = idx < providers.length - 1 ? idx + 1 : 0; + } + setSelectedProvider(providers[next]); + if (providerSidebarRef.current) { + const buttons = providerSidebarRef.current.querySelectorAll("button"); + buttons[next]?.scrollIntoView({ + block: "nearest", + inline: "nearest", + behavior: "smooth", + }); + } + return; + } + + if (count === 0) return; + + switch (e.key) { + case "ArrowDown": + e.preventDefault(); + setFocusedIndex((prev) => (prev < count - 1 ? prev + 1 : 0)); + break; + case "ArrowUp": + e.preventDefault(); + setFocusedIndex((prev) => (prev > 0 ? prev - 1 : count - 1)); + break; + case "Enter": + e.preventDefault(); + if (focusedIndex >= 0 && focusedIndex < count) { + handleSelectItem(currentDisplayItems[focusedIndex]); + } + break; + case "Home": + e.preventDefault(); + setFocusedIndex(0); + break; + case "End": + e.preventDefault(); + setFocusedIndex(count - 1); + break; + } + }, + [currentDisplayItems, focusedIndex, activeProviders, selectedProvider, handleSelectItem] + ); + + // ─── Render: Provider sidebar ─── + const renderProviderSidebar = () => { + const configuredCount = configuredProviderSet.size; + + return ( +
+ {!isMobile && ( +
+ +
+ )} + {isMobile && ( +
+ +
+ )} +
+ {activeProviders.map((provider, idx) => { + const isAll = provider === "all"; + const isActive = selectedProvider === provider; + const count = providerModelCounts[provider] || 0; + const isConfigured = isAll || configuredProviderSet.has(provider); + + // Separator between configured and unconfigured providers + // idx 0 is "all", configured run from 1..configuredCount, unconfigured start at configuredCount+1 + const showSeparator = !isAll && idx === configuredCount + 1 && configuredCount > 0; + + return ( + + {showSeparator && + (isMobile ? ( +
+ ) : ( +
+ ))} + + + + + + {isAll ? "All Models" : formatProviderName(provider)} + {isConfigured ? ` (${count})` : " (not configured)"} + + + + ); + })} +
+ {!isMobile && ( +
+ +
+ )} + {isMobile && ( +
+ +
+ )} +
+ ); + }; + + // ─── Render: Model card ─── + const getSelectedId = () => { + switch (activeTab) { + case "llm": + return currentLLMConfig?.id; + case "image": + return currentImageConfig?.id; + case "vision": + return currentVisionConfig?.id; + } + }; + + const renderModelCard = (item: DisplayItem, index: number) => { + const { config, isAutoMode } = item; + const isSelected = getSelectedId() === config.id; + const isFocused = focusedIndex === index; + const hasCitations = "citations_enabled" in config && !!config.citations_enabled; + const hasPremiumStatus = "is_premium" in config && !isAutoMode; + const isPremium = hasPremiumStatus && !!(config as Record).is_premium; + // Chat-tab only: surface an amber "No image" hint when the + // composer carries images and the catalog reports the model as + // non-vision. This is purely advisory — selection is *not* + // blocked. The backend's narrow safety net + // (`is_known_text_only_chat_model`) is the source of truth for + // rejecting image turns, and it only fires when LiteLLM + // explicitly marks the model as text-only. A model surfaced as + // `supports_image_input=false` here may still be capable in + // practice (unknown / unmapped LiteLLM entry), so we let the + // user pick it and the provider response decide. + const isImageIncompatibleChatModel = + activeTab === "llm" && + hasPendingImages && + "supports_image_input" in config && + (config as Record).supports_image_input === false; + + return ( +
handleSelectItem(item)} + onKeyDown={ + isMobile + ? undefined + : (e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + handleSelectItem(item); + } + } + } + onMouseEnter={() => setFocusedIndex(index)} + className={cn( + "group flex items-center gap-2.5 px-3 py-2 rounded-xl", + "transition-colors duration-150 mx-2 cursor-pointer", + "hover:bg-accent hover:text-accent-foreground", + isFocused && "bg-accent text-accent-foreground", + isSelected && "bg-accent text-accent-foreground" + )} + > + {/* Provider icon */} +
+ {getProviderIcon(config.provider as string, { + isAutoMode, + className: "size-5", + })} +
+ + {/* Model info */} +
+
+ + {isAutoMode && ( + + Recommended + + )} + {isImageIncompatibleChatModel && ( + + No image + + )} +
+ {isAutoMode ? ( +
+ Auto Mode +
+ ) : ( + (hasPremiumStatus || hasCitations) && ( +
+ {hasPremiumStatus && ( + + {isPremium ? "Premium" : "Free"} + + )} + {hasCitations && ( + + Citations + + )} +
+ ) + )} +
+ + {/* Actions */} +
+ {!isAutoMode && ( + + )} + {isSelected && ( +
+ +
+ )} +
+
+ ); + }; + + // ─── Render: Full content ─── + const renderContent = () => { + const globalItems = currentDisplayItems.filter((i) => i.isGlobal); + const userItems = currentDisplayItems.filter((i) => !i.isGlobal); + const globalStartIdx = 0; + const userStartIdx = globalItems.length; + + const addHandler = + activeTab === "llm" ? onAddNewLLM : activeTab === "image" ? onAddNewImage : onAddNewVision; + const addLabel = + activeTab === "llm" + ? "Add Model" + : activeTab === "image" + ? "Add Image Model" + : "Add Vision Model"; + + return ( +
+ {/* Tab header */} +
+
+ {( + [ + { + value: "llm" as const, + icon: Zap, + label: "LLM", + }, + { + value: "image" as const, + icon: ImageIcon, + label: "Image", + }, + { + value: "vision" as const, + icon: ScanEye, + label: "Vision", + }, + ] as const + ).map(({ value, icon: Icon, label }) => ( + + ))} +
+
+ + {/* Two-pane layout */} +
+ {/* Provider sidebar */} + {renderProviderSidebar()} + + {/* Main content */} +
+ {/* Search */} +
+ + setSearchQuery(e.target.value)} + onKeyDown={isMobile ? undefined : handleKeyDown} + role="combobox" + aria-expanded={true} + aria-controls="model-selector-list" + className={cn( + "w-full pl-8 pr-3 py-2.5 text-sm bg-transparent", + "focus:outline-none", + "placeholder:text-muted-foreground" + )} + /> +
+ + {/* Provider header when filtered */} + {selectedProvider !== "all" && ( +
+ {getProviderIcon(selectedProvider, { + className: "size-4", + })} + {formatProviderName(selectedProvider)} + + {configuredProviderSet.has(selectedProvider) + ? `${providerModelCounts[selectedProvider] || 0} models` + : "Not configured"} + +
+ )} + + {/* Model list */} +
+ {currentDisplayItems.length === 0 ? ( +
+ {selectedProvider !== "all" && !configuredProviderSet.has(selectedProvider) ? ( + <> +
+ {getProviderIcon(selectedProvider, { + className: "size-10", + })} +
+

+ No {formatProviderName(selectedProvider)} models configured +

+

+ Add a model with this provider to get started +

+ {addHandler && ( + + )} + + ) : searchQuery ? ( + <> + +

No models found

+

+ Try a different search term +

+ + ) : ( + <> +

+ No models configured +

+

+ Configure models in your search space settings +

+ + )} +
+ ) : ( + <> + {globalItems.length > 0 && ( + <> +
+ Global Models +
+ {globalItems.map((item, i) => renderModelCard(item, globalStartIdx + i))} + + )} + {globalItems.length > 0 && userItems.length > 0 && ( +
+ )} + {userItems.length > 0 && ( + <> +
+ Your Configurations +
+ {userItems.map((item, i) => renderModelCard(item, userStartIdx + i))} + + )} + + )} +
+ + {/* Add model button */} + {addHandler && ( +
+ +
+ )} +
+
+
+ ); + }; + + // ─── Trigger button ─── + const triggerButton = ( ); + // ─── Shell: Drawer on mobile, Popover on desktop ─── if (isMobile) { return ( - {trigger} + {triggerButton} - - Select Chat Model + + Select Model - {content} +
{renderContent()}
); @@ -306,12 +1435,14 @@ export function ModelSelector({ return ( - {trigger} + {triggerButton} e.preventDefault()} > - {content} + {renderContent()} ); diff --git a/surfsense_web/components/providers/ZeroProvider.tsx b/surfsense_web/components/providers/ZeroProvider.tsx index 35d51311a..5bb43db99 100644 --- a/surfsense_web/components/providers/ZeroProvider.tsx +++ b/surfsense_web/components/providers/ZeroProvider.tsx @@ -12,15 +12,7 @@ import { getBearerToken, handleUnauthorized, refreshAccessToken } from "@/lib/au import { queries } from "@/zero/queries"; import { schema } from "@/zero/schema"; -const configuredCacheURL = process.env.NEXT_PUBLIC_ZERO_CACHE_URL; - -function getCacheURL() { - if (configuredCacheURL) return configuredCacheURL; - if (typeof window !== "undefined") { - return `${window.location.origin}/zero`; - } - return "http://localhost:4848"; -} +const cacheURL = process.env.NEXT_PUBLIC_ZERO_CACHE_URL || "http://localhost:4848"; function ZeroAuthSync() { const zero = useZero(); @@ -50,7 +42,6 @@ function ZeroAuthSync() { export function ZeroProvider({ children }: { children: React.ReactNode }) { const { data: user } = useAtomValue(currentUserAtom); - const cacheURL = useMemo(() => getCacheURL(), []); const userId = user?.id; const hasUser = !!userId; @@ -74,7 +65,7 @@ export function ZeroProvider({ children }: { children: React.ReactNode }) { cacheURL, auth, }), - [userID, context, cacheURL, auth] + [userID, context, auth] ); return ( diff --git a/surfsense_web/components/providers/runtime-config.server.tsx b/surfsense_web/components/providers/runtime-config.server.tsx deleted file mode 100644 index c515820c2..000000000 --- a/surfsense_web/components/providers/runtime-config.server.tsx +++ /dev/null @@ -1,19 +0,0 @@ -import { connection } from "next/server"; -import { RuntimeConfigProvider } from "@/components/providers/runtime-config"; -import { - BUILD_TIME_AUTH_TYPE, - BUILD_TIME_DEPLOYMENT_MODE, - BUILD_TIME_ETL_SERVICE, -} from 
"@/lib/env-config"; - -export async function RuntimeConfig({ children }: { children: React.ReactNode }) { - await connection(); - - const value = { - authType: process.env.AUTH_TYPE ?? BUILD_TIME_AUTH_TYPE, - etlService: process.env.ETL_SERVICE ?? BUILD_TIME_ETL_SERVICE, - deploymentMode: process.env.DEPLOYMENT_MODE ?? BUILD_TIME_DEPLOYMENT_MODE, - }; - - return {children}; -} diff --git a/surfsense_web/components/providers/runtime-config.tsx b/surfsense_web/components/providers/runtime-config.tsx deleted file mode 100644 index 560acd597..000000000 --- a/surfsense_web/components/providers/runtime-config.tsx +++ /dev/null @@ -1,48 +0,0 @@ -"use client"; - -import { createContext, useContext } from "react"; - -export type AuthType = "LOCAL" | "GOOGLE" | string; -export type DeploymentMode = "self-hosted" | "cloud" | string; - -export interface RuntimeConfigValue { - authType: AuthType; - etlService: string; - deploymentMode: DeploymentMode; -} - -const RuntimeConfigContext = createContext(null); - -export function RuntimeConfigProvider({ - value, - children, -}: { - value: RuntimeConfigValue; - children: React.ReactNode; -}) { - return {children}; -} - -export function useRuntimeConfig() { - const context = useContext(RuntimeConfigContext); - if (!context) { - throw new Error("useRuntimeConfig must be used within RuntimeConfigProvider"); - } - return context; -} - -export function useIsLocalAuth() { - return useRuntimeConfig().authType === "LOCAL"; -} - -export function useIsGoogleAuth() { - return useRuntimeConfig().authType === "GOOGLE"; -} - -export function useIsSelfHosted() { - return useRuntimeConfig().deploymentMode === "self-hosted"; -} - -export function useIsCloud() { - return useRuntimeConfig().deploymentMode === "cloud"; -} diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 53b0c9867..682235e0f 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ 
b/surfsense_web/components/report-panel/report-panel.tsx @@ -22,7 +22,7 @@ import { Spinner } from "@/components/ui/spinner"; import { useMediaQuery } from "@/hooks/use-media-query"; import { baseApiService } from "@/lib/apis/base-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; function ReportPanelSkeleton() { return ( @@ -245,7 +245,7 @@ export function ReportPanelContent({ URL.revokeObjectURL(url); } else { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/reports/${activeReportId}/export`, { format }), + `${BACKEND_URL}/api/v1/reports/${activeReportId}/export?format=${format}`, { method: "GET" } ); @@ -278,7 +278,7 @@ export function ReportPanelContent({ setSaving(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/reports/${activeReportId}/content`), + `${BACKEND_URL}/api/v1/reports/${activeReportId}/content`, { method: "PUT", headers: { "Content-Type": "application/json" }, @@ -506,11 +506,7 @@ export function ReportPanelContent({
) : reportContent.content_type === "typst" ? ( diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx new file mode 100644 index 000000000..507a263e0 --- /dev/null +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -0,0 +1,423 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { NewLLMConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface AgentModelManagerProps { + searchSpaceId: number; +} + +function 
getInitials(name: string): string { + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + // Mutations + const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( + deleteNewLLMConfigMutationAtom + ); + + // Queries + const { + data: configs, + isFetching: isLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(newLLMConfigsAtom); + const { data: globalConfigs = [] } = useAtomValue(globalNewLLMConfigsAtom); + + // Members for user resolution + const { data: members } = useAtomValue(membersAtom); + const memberMap = useMemo(() => { + const map = new Map(); + if (members) { + for (const m of members) { + map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + // Permissions + const { data: access } = useAtomValue(myAccessAtom); + const canCreate = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:create") ?? false)); + const canUpdate = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:update") ?? false)); + const canDelete = + !!access && (access.is_owner || (access.permissions?.includes("llm_configs:delete") ?? 
false)); + const isReadOnly = !canCreate && !canUpdate && !canDelete; + + // Local state + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState(null); + const [configToDelete, setConfigToDelete] = useState(null); + + const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation state + } + }; + + const openEditDialog = (config: NewLLMConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + + return ( +
+ {/* Header actions */} +
+ + {canCreate && ( + + )} +
+ + {/* Fetch Error Alert */} + {fetchError && ( +
+ + + + {fetchError?.message ?? "Failed to load configurations"} + + +
+ )} + + {/* Read-only / Limited permissions notice */} + {access && !isLoading && isReadOnly && ( +
+ + + +

+ You have read-only access to LLM + configurations. Contact a space owner to request additional permissions. +

+
+
+
+ )} + {access && !isLoading && !isReadOnly && (!canCreate || !canUpdate || !canDelete) && ( +
+ + + +

+ You can{" "} + {[canCreate && "create", canUpdate && "edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + configurations + {!canDelete && ", but cannot delete them"}. +

+
+
+
+ )} + + {/* Global Configs Info */} + {(isLoading || globalConfigs.length > 0) && ( + + + + {isLoading ? ( +
+ +
+ ) : ( +

+ + {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} + {" "} + available from your administrator. +

+ )} +
+
+ )} + + {/* Loading Skeleton */} + {isLoading && ( +
+ {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + + + + + + + + ))} +
+ )} + + {/* Configurations List */} + {!isLoading && ( +
+ {configs?.length === 0 ? ( +
+ + +

No Models Yet

+

+ {canCreate + ? "Add your first model to power chat, reports, and other agent capabilities" + : "No models have been added to this space yet. Contact a space owner to add one"} +

+
+
+
+ ) : ( +
+ {configs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( +
+ + + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
+
+ {(canUpdate || canDelete) && ( +
+ {canUpdate && ( + + + + + + Edit + + + )} + {canDelete && ( + + + + + + Delete + + + )} +
+ )} +
+ + {/* Feature badges */} +
+ {config.citations_enabled && ( + + Citations + + )} + {!config.use_default_system_instructions && + config.system_instructions && ( + + + Custom + + )} +
+ + {/* Footer: Date + Creator */} +
+ +
+ + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + + {member && ( + <> + + + + +
+ + {member.avatarUrl && ( + + )} + + {getInitials(member.name)} + + + + {member.name} + +
+
+ + {member.email || member.name} + +
+
+ + )} +
+
+
+
+
+ ); + })} +
+ )} +
+ )} + + {/* Add/Edit Configuration Dialog */} + { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + {/* Delete Confirmation Dialog */} + !open && setConfigToDelete(null)} + > + + + Delete Model + + Are you sure you want to delete{" "} + {configToDelete?.name}? This + action cannot be undone. + + + + Cancel + + {isDeleting ? ( + <> + + Deleting + + ) : ( + "Delete" + )} + + + + +
+ ); +} diff --git a/surfsense_web/components/settings/general-settings-manager.tsx b/surfsense_web/components/settings/general-settings-manager.tsx index 68ff21f07..a308acfad 100644 --- a/surfsense_web/components/settings/general-settings-manager.tsx +++ b/surfsense_web/components/settings/general-settings-manager.tsx @@ -12,7 +12,7 @@ import { Label } from "@/components/ui/label"; import { Skeleton } from "@/components/ui/skeleton"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; -import { buildBackendUrl } from "@/lib/env-config"; +import { BACKEND_URL } from "@/lib/env-config"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import { Spinner } from "../ui/spinner"; @@ -49,7 +49,7 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager setIsExporting(true); try { const response = await authenticatedFetch( - buildBackendUrl(`/api/v1/search-spaces/${searchSpaceId}/export`), + `${BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`, { method: "GET" } ); if (!response.ok) { diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx new file mode 100644 index 000000000..494f7aae9 --- /dev/null +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -0,0 +1,489 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; +import { Alert, 
AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface ImageModelManagerProps { + searchSpaceId: number; +} + +function getInitials(name: string): string { + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + + const { + mutateAsync: deleteConfig, + isPending: isDeleting, + error: deleteError, + } = useAtomValue(deleteImageGenConfigMutationAtom); + + const { + data: userConfigs, + isFetching: configsLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(imageGenConfigsAtom); + const { data: globalConfigs = [], isFetching: globalLoading } = + useAtomValue(globalImageGenConfigsAtom); + + const { data: members } = useAtomValue(membersAtom); + const memberMap = useMemo(() => { + const map = new Map(); + if (members) { + for (const m of members) { + 
map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + const { data: access } = useAtomValue(myAccessAtom); + const canCreate = + !!access && + (access.is_owner || (access.permissions?.includes("image_generations:create") ?? false)); + const canDelete = + !!access && + (access.is_owner || (access.permissions?.includes("image_generations:delete") ?? false)); + const canUpdate = canCreate; + const isReadOnly = !canCreate && !canDelete; + + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState(null); + const [configToDelete, setConfigToDelete] = useState(null); + + const isLoading = configsLoading || globalLoading; + const errors = [deleteError, fetchError].filter(Boolean) as Error[]; + + const openEditDialog = (config: ImageGenerationConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + + const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation + } + }; + + return ( +
+ {/* Header actions */} +
+ + {canCreate && ( + + )} +
+ + {/* Errors */} + {errors.map((err) => ( +
+ + + {err?.message} + +
+ ))} + + {/* Read-only / Limited permissions notice */} + {access && !isLoading && isReadOnly && ( +
+ + + +

+ You have read-only access to image generation + configurations. Contact a space owner to request additional permissions. +

+
+
+
+ )} + {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( +
+ + + +

+ You can{" "} + {[canCreate && "create and edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + image model configurations + {!canDelete && ", but cannot delete them"}. +

+
+
+
+ )} + + {/* Global info */} + {(isLoading || + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( + + + + {isLoading ? ( +
+ +
+ ) : ( +

+ + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} + global image{" "} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === + 1 + ? "model" + : "models"} + {" "} + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} +

+ )} +
+
+ )} + + {/* Global Image Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +
+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ +
+ + {cfg.model_name} + +
+
+
+
+ ); + })} +
+
+ )} + + {/* Loading Skeleton */} + {isLoading && ( +
+
+
+ {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + + + + + + + + ))} +
+
+
+ )} + + {/* User Configs */} + {!isLoading && ( +
+ {(userConfigs?.length ?? 0) === 0 ? ( + + +

No Image Models Yet

+

+ {canCreate + ? "Add your own image generation model (DALL-E 3, GPT Image 1, etc.)" + : "No image models have been added to this space yet. Contact a space owner to add one."} +

+
+
+ ) : ( +
+ {userConfigs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( +
+ + + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
+
+ {(canUpdate || canDelete) && ( +
+ {canUpdate && ( + + + + + + Edit + + + )} + {canDelete && ( + + + + + + Delete + + + )} +
+ )} +
+ + {/* Footer: Date + Creator */} +
+ +
+ + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + + {member && ( + <> + + + + +
+ + {member.avatarUrl && ( + + )} + + {getInitials(member.name)} + + + + {member.name} + +
+
+ + {member.email || member.name} + +
+
+ + )} +
+
+
+
+
+ ); + })} +
+ )} +
+ )} + + {/* Create/Edit Dialog — shared component */} + { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + {/* Delete Confirmation */} + !open && setConfigToDelete(null)} + > + + + Delete Image Model + + Are you sure you want to delete{" "} + {configToDelete?.name}? + + + + Cancel + + Delete + {isDeleting && } + + + + +
+ ); +} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx new file mode 100644 index 000000000..c32e79a8e --- /dev/null +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -0,0 +1,443 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { + AlertCircle, + Bot, + CircleCheck, + CircleDashed, + FileText, + ImageIcon, + RefreshCw, + ScanEye, +} from "lucide-react"; +import { useCallback, useEffect, useState } from "react"; +import { toast } from "sonner"; +import { + globalImageGenConfigsAtom, + imageGenConfigsAtom, +} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { + globalNewLLMConfigsAtom, + llmPreferencesAtom, + newLLMConfigsAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { cn } from "@/lib/utils"; + +const ROLE_DESCRIPTIONS = { + agent: { + icon: Bot, + title: "Agent LLM", + description: "Primary LLM for chat interactions and agent operations", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "agent_llm_id" as const, + configType: "llm" as const, + }, + image_generation: { + icon: ImageIcon, + title: "Image Generation Model", + description: "Model 
used for AI image generation (DALL-E, GPT Image, etc.)", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "image_generation_config_id" as const, + configType: "image" as const, + }, + vision: { + icon: ScanEye, + title: "Vision LLM", + description: "Vision-capable model for screenshot analysis and context extraction", + color: "text-muted-foreground", + bgColor: "bg-muted", + prefKey: "vision_llm_config_id" as const, + configType: "vision" as const, + }, +}; + +interface LLMRoleManagerProps { + searchSpaceId: number; +} + +export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { + // LLM configs + const { + data: newLLMConfigs = [], + isFetching: configsLoading, + error: configsError, + refetch: refreshConfigs, + } = useAtomValue(newLLMConfigsAtom); + const { + data: globalConfigs = [], + isFetching: globalConfigsLoading, + error: globalConfigsError, + } = useAtomValue(globalNewLLMConfigsAtom); + + // Image gen configs + const { + data: userImageConfigs = [], + isFetching: imageConfigsLoading, + error: imageConfigsError, + } = useAtomValue(imageGenConfigsAtom); + const { + data: globalImageConfigs = [], + isFetching: globalImageConfigsLoading, + error: globalImageConfigsError, + } = useAtomValue(globalImageGenConfigsAtom); + + // Vision LLM configs + const { + data: userVisionConfigs = [], + isFetching: visionConfigsLoading, + error: visionConfigsError, + } = useAtomValue(visionLLMConfigsAtom); + const { + data: globalVisionConfigs = [], + isFetching: globalVisionConfigsLoading, + error: globalVisionConfigsError, + } = useAtomValue(globalVisionLLMConfigsAtom); + + // Preferences + const { + data: preferences = {}, + isFetching: preferencesLoading, + error: preferencesError, + } = useAtomValue(llmPreferencesAtom); + + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const [assignments, setAssignments] = useState>(() => ({ + agent_llm_id: preferences.agent_llm_id ?? 
null, + image_generation_config_id: preferences.image_generation_config_id ?? null, + vision_llm_config_id: preferences.vision_llm_config_id ?? null, + })); + + // Sync local state when preferences load/change. Without this, the selects + // stay on their initial (often empty) value while the query is in flight, + // so a saved assignment — including Auto mode (id 0) — never appears. + useEffect(() => { + setAssignments({ + agent_llm_id: preferences.agent_llm_id ?? null, + image_generation_config_id: preferences.image_generation_config_id ?? null, + vision_llm_config_id: preferences.vision_llm_config_id ?? null, + }); + }, [ + preferences.agent_llm_id, + preferences.image_generation_config_id, + preferences.vision_llm_config_id, + ]); + + const [savingRole, setSavingRole] = useState(null); + + const handleRoleAssignment = useCallback( + async (prefKey: string, configId: string) => { + // "unassigned" clears the role (null). Every other option — including + // Auto mode, whose config id is 0 — must be sent as-is. Using a falsy + // check here (e.g. `value || undefined`) would drop id 0 and silently + // fail to persist Auto mode. + const value = configId === "unassigned" ? null : Number(configId); + + setAssignments((prev) => ({ ...prev, [prefKey]: value })); + setSavingRole(prefKey); + + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { [prefKey]: value }, + }); + toast.success("Role assignment updated"); + } finally { + setSavingRole(null); + } + }, + [updatePreferences, searchSpaceId] + ); + + // Combine global and custom LLM configs + const allLLMConfigs = [ + ...globalConfigs.map((config) => ({ ...config, is_global: true })), + ...newLLMConfigs.filter((config) => config.id && config.id.toString().trim() !== ""), + ]; + + // Combine global and custom image gen configs + const allImageConfigs = [ + ...globalImageConfigs.map((config) => ({ ...config, is_global: true })), + ...(userImageConfigs ?? 
[]).filter((config) => config.id && config.id.toString().trim() !== ""), + ]; + + // Combine global and custom vision LLM configs + const allVisionConfigs = [ + ...globalVisionConfigs.map((config) => ({ ...config, is_global: true })), + ...(userVisionConfigs ?? []).filter( + (config) => config.id && config.id.toString().trim() !== "" + ), + ]; + + const isLoading = + configsLoading || + preferencesLoading || + globalConfigsLoading || + imageConfigsLoading || + globalImageConfigsLoading || + visionConfigsLoading || + globalVisionConfigsLoading; + const hasError = + configsError || + preferencesError || + globalConfigsError || + imageConfigsError || + globalImageConfigsError || + visionConfigsError || + globalVisionConfigsError; + const hasAnyConfigs = allLLMConfigs.length > 0 || allImageConfigs.length > 0; + + return ( +
+ {/* Header actions */} +
+ +
+ + {/* Error Alert */} + {hasError && ( +
+ + + + {(configsError?.message ?? "Failed to load LLM configurations") || + (preferencesError?.message ?? "Failed to load preferences") || + (globalConfigsError?.message ?? "Failed to load global configurations")} + + +
+ )} + + {/* Loading Skeleton */} + {isLoading && ( +
+ {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + + + + + + + + ))} +
+ )} + + {/* No configs warning */} + {!isLoading && !hasError && !hasAnyConfigs && ( + + + + No configurations found. Please add at least one LLM provider or image model in the + respective settings tabs before assigning roles. + + + )} + + {/* Role Assignment Cards */} + {!isLoading && !hasError && hasAnyConfigs && ( +
+ {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[role.prefKey as keyof typeof assignments]; + + // Pick the right config lists based on role type + const roleGlobalConfigs = + role.configType === "image" + ? globalImageConfigs + : role.configType === "vision" + ? globalVisionConfigs + : globalConfigs; + const roleUserConfigs = + role.configType === "image" + ? (userImageConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") + : role.configType === "vision" + ? (userVisionConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") + : newLLMConfigs.filter((c) => c.id && c.id.toString().trim() !== ""); + const roleAllConfigs = + role.configType === "image" + ? allImageConfigs + : role.configType === "vision" + ? allVisionConfigs + : allLLMConfigs; + + const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment); + const isAssigned = !!assignedConfig; + + return ( +
+ + + {/* Role Header */} +
+
+
+ +
+
+

{role.title}

+

+ {role.description} +

+
+
+ {savingRole === role.prefKey ? ( + + ) : isAssigned ? ( + + ) : ( + + )} +
+ + {/* Selector */} +
+ + +
+
+
+
+ ); + })} +
+ )} +
+ ); +} diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx deleted file mode 100644 index 3b30b1558..000000000 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ /dev/null @@ -1,153 +0,0 @@ -"use client"; - -import { useAtom, useAtomValue } from "jotai"; -import { Dot } from "lucide-react"; -import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - globalModelConnectionsAtom, - modelConnectionsAtom, - modelRolesAtom, -} from "@/atoms/model-connections/model-connections-query.atoms"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import { AUTO_PROVIDER_ICON_KEY, getProviderIcon } from "@/lib/provider-icons"; -import { ModelProviderConnectionsPanel } from "./model-connections/model-provider-connections-panel"; -import { capability, modelLabel } from "./model-connections/model-utils"; -import { providerDisplay, providerIcon } from "./model-connections/provider-metadata"; - -function flattenModels(connections: ConnectionRead[]) { - return connections.flatMap((connection) => - connection.models.map((model) => ({ - ...model, - connectionName: providerDisplay(connection.provider).name, - connectionId: connection.id, - provider: connection.provider, - })) - ); -} - -function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) { - if (!modelId) return "0"; - return models.some((model) => model.id === modelId) ? 
String(modelId) : "0"; -} - -function renderAutoModeOption() { - return ( - - - {getProviderIcon(AUTO_PROVIDER_ICON_KEY)} - Auto mode - - - ); -} - -export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { - const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); - const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); - const [{ data: roles }] = useAtom(modelRolesAtom); - const updateRoles = useAtomValue(updateModelRolesMutationAtom); - - const allConnections = [...globalConnections, ...connections]; - const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); - const chatModels = enabledModels.filter((model) => capability(model, "chat")); - const visionModels = enabledModels.filter((model) => capability(model, "vision")); - const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); - - function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { - return ( - - - {providerIcon(model.provider)} - - {modelLabel(model)} - - - - ); - } - - return ( -
-
-
-

Model Roles

-

- Pick which enabled model powers chat, vision, and image generation for this search - space. -

-
-
-
- -

- Primary model for chat responses and agent tasks. You can also change it from the - chat. -

- -
-
- -

- Used to understand images in uploads, documents, connectors, and automations. Falls - back to chat model when possible. -

- -
-
- -

Used when generating images in chat.

- -
-
-
- - - - -
- ); -} diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx deleted file mode 100644 index 451f053db..000000000 --- a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { ApiKeyField } from "./connect-fields"; -import { - isValidAzureTargetUri, - type ProviderConnectFormProps, - parseAzureTargetUri, -} from "./provider-metadata"; - -/** - * Azure OpenAI connect form. The user pastes a single Target URI, which we parse - * into api base, api version, and the deployment name (seeded as the model). - */ -export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [targetUri, setTargetUri] = useState(""); - const [apiKey, setApiKey] = useState(""); - const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); - - useEffect(() => { - const parsed = parseAzureTargetUri(targetUri); - onDraftChange( - { - base_url: parsed?.origin ?? null, - api_key: apiKey || null, - extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, - seedModelId: parsed?.deploymentName || undefined, - }, - canSubmit - ); - }, [apiKey, canSubmit, onDraftChange, targetUri]); - - return ( -
-
- - setTargetUri(event.target.value)} - placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" - /> -

- Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and - API version). -

-
- -
- ); -} diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx deleted file mode 100644 index f76308421..000000000 --- a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx +++ /dev/null @@ -1,120 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { ApiKeyField } from "./connect-fields"; -import { - AWS_REGION_OPTIONS, - BEDROCK_AUTH_ACCESS_KEY, - BEDROCK_AUTH_IAM, - BEDROCK_AUTH_LONG_TERM_API_KEY, - type ProviderConnectFormProps, -} from "./provider-metadata"; - -/** - * Amazon Bedrock connect form. Region + auth method drive which AWS credentials - * are collected; everything rides along in `extra.litellm_params`. - */ -export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [region, setRegion] = useState(""); - const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); - const [accessKeyId, setAccessKeyId] = useState(""); - const [secretAccessKey, setSecretAccessKey] = useState(""); - const [bearerToken, setBearerToken] = useState(""); - - const canSubmit = (() => { - if (!region) return false; - if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { - return Boolean(accessKeyId && secretAccessKey); - } - if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { - return Boolean(bearerToken); - } - return true; - })(); - - useEffect(() => { - const params: Record = { aws_region_name: region }; - if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { - params.aws_access_key_id = accessKeyId; - params.aws_secret_access_key = secretAccessKey; - } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { - params.aws_bearer_token_bedrock = bearerToken; - } - onDraftChange({ base_url: 
null, api_key: null, extra: { litellm_params: params } }, canSubmit); - }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]); - - return ( -
-
- - -
-
- - -
- {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( - <> -
- - setAccessKeyId(event.target.value)} - placeholder="Enter your AWS access key ID" - /> -
- - - ) : null} - {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( - - ) : null} - {authMethod === BEDROCK_AUTH_IAM ? ( -

- SurfSense will use the IAM role attached to the environment it's running in to - authenticate. -

- ) : null} -

- Add Bedrock model IDs from the provider's settings after connecting. -

-
- ); -} diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx deleted file mode 100644 index 584fb98b0..000000000 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ /dev/null @@ -1,105 +0,0 @@ -import { Eye, EyeOff } from "lucide-react"; -import type { ReactNode } from "react"; -import { useState } from "react"; -import { Button } from "@/components/ui/button"; -import { DialogFooter } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Spinner } from "@/components/ui/spinner"; - -interface ApiBaseUrlFieldProps { - value: string; - onChange: (value: string) => void; - /** Placeholder, typically the provider's prefilled default base URL. */ - placeholder?: string; - hint?: ReactNode; -} - -/** Shared API Base URL input. The prefilled default is passed in via `value`. */ -export function ApiBaseUrlField({ value, onChange, placeholder, hint }: ApiBaseUrlFieldProps) { - return ( -
- - onChange(event.target.value)} - placeholder={placeholder || "https://api.example.com/v1"} - /> - {hint ?

{hint}

: null} -
- ); -} - -interface ApiKeyFieldProps { - value: string; - onChange: (value: string) => void; - label?: string; - placeholder?: string; -} - -/** Shared masked API Key input. */ -export function ApiKeyField({ - value, - onChange, - label = "API Key", - placeholder = "API key", -}: ApiKeyFieldProps) { - const [showApiKey, setShowApiKey] = useState(false); - - return ( -
- -
- onChange(event.target.value)} - placeholder={placeholder} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - -
-
- ); -} - -interface ConnectFormFooterProps { - onCancel: () => void; - onSubmit: () => void; - canSubmit: boolean; - isPending: boolean; -} - -/** Shared Cancel / Connect footer for every provider connect form. */ -export function ConnectFormFooter({ - onCancel, - onSubmit, - canSubmit, - isPending, -}: ConnectFormFooterProps) { - return ( - - - - - ); -} diff --git a/surfsense_web/components/settings/model-connections/connection-card.tsx b/surfsense_web/components/settings/model-connections/connection-card.tsx deleted file mode 100644 index b482cac9f..000000000 --- a/surfsense_web/components/settings/model-connections/connection-card.tsx +++ /dev/null @@ -1,88 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { Trash2 } from "lucide-react"; -import { deleteModelConnectionMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import type { ConnectionRead } from "@/contracts/types/model-connections.types"; -import { ConnectionSettingsDialog } from "./connection-settings-dialog"; -import { providerDisplay, providerIcon } from "./provider-metadata"; - -export function ConnectionCard({ connection }: { connection: ConnectionRead }) { - const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); - - const providerMeta = providerDisplay(connection.provider); - const providerLabel = providerMeta.name; - - function deleteCurrentConnection() { - deleteConnection.mutate(connection.id); - } - - return ( -
-
-
-
- {providerIcon(connection.provider)} - {providerLabel} - {connection.scope === "GLOBAL" ? ( - - Default - - ) : null} -
-
- {connection.base_url || "Provider default endpoint"} -
-
-
- - - - - - - - Delete this provider? - - {providerLabel} and all of - its models will be removed from this search space. This cannot be undone. - - - - Cancel - - Delete - - - - -
-
-
- ); -} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx deleted file mode 100644 index 1f16c3bd0..000000000 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ /dev/null @@ -1,333 +0,0 @@ -import { useAtomValue } from "jotai"; -import { Eye, EyeOff, Settings } from "lucide-react"; -import { useMemo, useState } from "react"; -import { - addManualModelMutationAtom, - bulkUpdateModelsMutationAtom, - discoverConnectionModelsMutationAtom, - testPreviewModelMutationAtom, - updateModelConnectionMutationAtom, -} from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { Button } from "@/components/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import type { - ConnectionRead, - ConnectionUpdateRequest, -} from "@/contracts/types/model-connections.types"; -import { capability, type SelectableModel } from "./model-utils"; -import { ModelsSelectionPanel } from "./models-selection-panel"; -import { providerIcon } from "./provider-metadata"; - -interface ConnectionSettingsDialogProps { - connection: ConnectionRead; - providerLabel: string; -} - -function enabledModelIds(models: SelectableModel[]) { - return new Set( - models - .filter((model) => typeof model.id === "number" && model.enabled) - .map((model) => Number(model.id)) - ); -} - -export function ConnectionSettingsDialog({ - connection, - providerLabel, -}: ConnectionSettingsDialogProps) { - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const testPreviewModel = 
useAtomValue(testPreviewModelMutationAtom); - const updateConnection = useAtomValue(updateModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); - - const allowlist = Array.isArray(connection.extra?.model_ids) - ? (connection.extra.model_ids as string[]) - : []; - const [isOpen, setIsOpen] = useState(false); - const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); - const [apiKeyDraft, setApiKeyDraft] = useState(""); - const [showApiKey, setShowApiKey] = useState(false); - const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); - const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false); - const [draftEnabledModelIds, setDraftEnabledModelIds] = useState(() => - enabledModelIds(connection.models) - ); - - const isLocal = - connection.provider === "ollama_chat" || - connection.provider === "lm_studio" || - !connection.base_url?.startsWith("https"); - const hasConnectionChanges = - baseUrlDraft.trim() !== (connection.base_url ?? "") || - apiKeyDraft.trim() !== (connection.api_key ?? ""); - const draftModels = useMemo( - () => - connection.models.map((model) => - typeof model.id === "number" - ? { ...model, enabled: draftEnabledModelIds.has(model.id) } - : model - ), - [connection.models, draftEnabledModelIds] - ); - const hasModelChanges = connection.models.some( - (model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id) !== model.enabled - ); - const canUpdate = hasConnectionChanges || hasModelChanges; - - function handleOpenChange(open: boolean) { - setIsOpen(open); - if (open) { - setBaseUrlDraft(connection.base_url ?? ""); - setApiKeyDraft(connection.api_key ?? 
""); - setShowApiKey(false); - setAllowlistText(allowlist.join(", ")); - setIsSavingConnectionSettings(false); - setDraftEnabledModelIds(enabledModelIds(connection.models)); - } - } - - async function saveModelChanges() { - const toEnable = connection.models - .filter((model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id)) - .filter((model) => !model.enabled) - .map((model) => Number(model.id)); - const toDisable = connection.models - .filter((model) => typeof model.id === "number" && !draftEnabledModelIds.has(model.id)) - .filter((model) => model.enabled) - .map((model) => Number(model.id)); - - if (toEnable.length > 0) { - await bulkUpdateModels.mutateAsync({ - connectionId: connection.id, - data: { model_ids: toEnable, enabled: true }, - }); - } - if (toDisable.length > 0) { - await bulkUpdateModels.mutateAsync({ - connectionId: connection.id, - data: { model_ids: toDisable, enabled: false }, - }); - } - } - - async function saveConnectionSettings() { - if (isSavingConnectionSettings) return; - - const data: ConnectionUpdateRequest = { - base_url: baseUrlDraft.trim() || null, - }; - - if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { - data.api_key = apiKeyDraft.trim() || null; - } - const apiKeyForTest = Object.hasOwn(data, "api_key") - ? (data.api_key ?? null) - : (connection.api_key ?? null); - - const enabledModels = draftModels.filter((model) => model.enabled); - const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; - setIsSavingConnectionSettings(true); - try { - if (hasConnectionChanges) { - if (testModel) { - const result = await testPreviewModel.mutateAsync({ - provider: connection.provider, - base_url: data.base_url, - api_key: apiKeyForTest, - scope: "SEARCH_SPACE", - search_space_id: connection.search_space_id, - extra: connection.extra ?? 
{}, - enabled: connection.enabled, - models: [], - model_id: testModel.model_id, - }); - if (!result.ok) return; - } - await updateConnection.mutateAsync({ id: connection.id, data }); - setApiKeyDraft(""); - } - - if (hasModelChanges) { - await saveModelChanges(); - } - } finally { - setIsSavingConnectionSettings(false); - } - } - - function saveAllowlist() { - const ids = allowlistText - .split(",") - .map((value) => value.trim()) - .filter(Boolean); - updateConnection.mutate({ - id: connection.id, - data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, - }); - } - - function handleToggleModel(model: SelectableModel, enabled: boolean) { - if (typeof model.id !== "number") return; - const modelId = model.id; - setDraftEnabledModelIds((current) => { - const next = new Set(current); - if (enabled) { - next.add(modelId); - } else { - next.delete(modelId); - } - return next; - }); - } - - function handleBulkToggle(models: SelectableModel[], enabled: boolean) { - const modelIds = models - .map((model) => model.id) - .filter((id): id is number => typeof id === "number"); - if (modelIds.length === 0) return; - setDraftEnabledModelIds((current) => { - const next = new Set(current); - for (const id of modelIds) { - if (enabled) { - next.add(id); - } else { - next.delete(id); - } - } - return next; - }); - } - - return ( - - - - - - -
- {providerIcon(connection.provider, "size-5")} -
- - Configure {providerLabel} - - - Manage credentials and choose which models are available from this provider. - -
-
-
- -
-
-
- - setBaseUrlDraft(event.target.value)} - placeholder="https://api.example.com/v1" - /> -

- Leave empty to use the provider default endpoint. -

-
- -
- -
- setApiKeyDraft(event.target.value)} - placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - -
-
- - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large catalogs. -

-
- ) : null} - - - - discoverModels.mutate(connection.id)} - onAddManual={(modelId) => - addManualModel.mutate({ - connectionId: connection.id, - data: { model_id: modelId }, - }) - } - onToggleModel={handleToggleModel} - onBulkToggle={handleBulkToggle} - /> -
-
- - - - -
-
- ); -} diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx deleted file mode 100644 index e3111202d..000000000 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { useEffect, useState } from "react"; -import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; -import type { ProviderConnectFormProps } from "./provider-metadata"; - -const OPTIONAL_API_KEY_PROVIDERS = new Set(["ollama_chat", "lm_studio", "openai_compatible"]); - -function baseUrlHint(provider: string) { - if (provider === "ollama_chat" || provider === "lm_studio") { - return "For local servers, use host.docker.internal instead of localhost."; - } - if (provider === "openai_compatible") { - return "Enter the full endpoint URL."; - } - if (provider === "openai" || provider === "anthropic" || provider === "openrouter") { - return "Override only if you route through a proxy or gateway."; - } - return undefined; -} - -/** - * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, - * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is - * prefilled from the provider default. - */ -export function DefaultConnectForm({ - provider, - defaultBaseUrl, - baseUrlRequired, - onDraftChange, -}: ProviderConnectFormProps) { - const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); - const [apiKey, setApiKey] = useState(""); - const isApiKeyOptional = OPTIONAL_API_KEY_PROVIDERS.has(provider); - const hint = baseUrlHint(provider); - const apiKeyValue = apiKey.trim(); - const canSubmit = - !(baseUrlRequired && !baseUrl.trim()) && (isApiKeyOptional || Boolean(apiKeyValue)); - - useEffect(() => { - onDraftChange( - { base_url: baseUrl || null, api_key: apiKeyValue || null, extra: {} }, - canSubmit - ); - }, [apiKeyValue, baseUrl, canSubmit, onDraftChange]); - - return ( -
- - -
- ); -} diff --git a/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx b/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx deleted file mode 100644 index a703ab1c8..000000000 --- a/surfsense_web/components/settings/model-connections/model-provider-connections-panel.tsx +++ /dev/null @@ -1,299 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { type ReactNode, useState } from "react"; -import { toast } from "sonner"; -import { - createModelConnectionMutationAtom, - previewConnectionModelsMutationAtom, - testPreviewModelMutationAtom, -} from "@/atoms/model-connections/model-connections-mutation.atoms"; -import { modelProvidersAtom } from "@/atoms/model-connections/model-connections-query.atoms"; -import { Button } from "@/components/ui/button"; -import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelSelection } from "@/contracts/types/model-connections.types"; -import { ConnectionCard } from "./connection-card"; -import { capability, type SelectableModel } from "./model-utils"; -import { ProviderConnectDialog } from "./provider-connect-dialog"; -import { - type ConnectionDraft, - PROVIDER_ORDER, - providerDisplay, - providerIcon, -} from "./provider-metadata"; - -interface ModelProviderConnectionsPanelProps { - searchSpaceId: number; - connections: ConnectionRead[]; - className?: string; - addProviderTitle?: string; - addProviderDescription?: string; - availableProvidersTitle?: string; - footerAction?: ReactNode; - showAddProviderHeader?: boolean; -} - -function toModelSelection(model: SelectableModel): ModelSelection { - return { - model_id: model.model_id, - display_name: model.display_name, - source: model.source || "DISCOVERED", - supports_chat: model.supports_chat, - max_input_tokens: model.max_input_tokens, - supports_image_input: model.supports_image_input, - supports_tools: model.supports_tools, - supports_image_generation: 
model.supports_image_generation, - enabled: model.enabled, - metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}), - }; -} - -export function ModelProviderConnectionsPanel({ - searchSpaceId, - connections, - className, - addProviderTitle = "Add Provider", - addProviderDescription = "SurfSense supports popular providers and self-hosted model endpoints.", - availableProvidersTitle = "Available Providers", - footerAction, - showAddProviderHeader = true, -}: ModelProviderConnectionsPanelProps) { - const { data: providers = [] } = useAtomValue(modelProvidersAtom); - const createConnection = useAtomValue(createModelConnectionMutationAtom); - const previewModels = useAtomValue(previewConnectionModelsMutationAtom); - const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); - - const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); - const [provider, setProvider] = useState("openai_compatible"); - const [connectModels, setConnectModels] = useState([]); - const selectedProvider = providers.find((item) => item.provider === provider); - - const sortedProviders = [...providers].sort((left, right) => { - const leftIndex = PROVIDER_ORDER.indexOf(left.provider); - const rightIndex = PROVIDER_ORDER.indexOf(right.provider); - if (leftIndex !== -1 || rightIndex !== -1) { - return ( - (leftIndex === -1 ? Number.MAX_SAFE_INTEGER : leftIndex) - - (rightIndex === -1 ? 
Number.MAX_SAFE_INTEGER : rightIndex) - ); - } - return providerDisplay(left.provider).name.localeCompare(providerDisplay(right.provider).name); - }); - - function resetConnectState() { - setConnectModels([]); - } - - function handleConnectOpenChange(open: boolean) { - setIsAddProviderOpen(open); - if (!open) { - resetConnectState(); - } - } - - function mergePreviewModels(fetchedModels: SelectableModel[]) { - setConnectModels((current) => { - const currentById = new Map(current.map((model) => [model.model_id, model])); - return fetchedModels.map((model) => { - const prior = currentById.get(model.model_id); - return { - ...toModelSelection(model), - enabled: prior ? prior.enabled : model.enabled, - }; - }); - }); - } - - function connectionModelsForDraft(draft: ConnectionDraft) { - const models = [...connectModels]; - if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { - models.push({ - model_id: draft.seedModelId, - display_name: draft.seedModelId, - source: "MANUAL", - enabled: true, - metadata: {}, - }); - } - return models; - } - - function representativeTestModel(models: ModelSelection[]) { - const enabledModels = models.filter((model) => model.enabled); - return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; - } - - // Each provider connect form builds its own credential payload; the backend - // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. 
- function handleCreate(draft: ConnectionDraft) { - const models = connectionModelsForDraft(draft); - const testModel = representativeTestModel(models); - if (!testModel) { - toast.error("Select at least one model before connecting"); - return; - } - - const request = { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE" as const, - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models, - }; - - testPreviewModel.mutate( - { ...request, model_id: testModel.model_id }, - { - onSuccess: (result) => { - if (!result.ok) return; - createConnection.mutate(request, { - onSuccess: () => { - setIsAddProviderOpen(false); - resetConnectState(); - }, - }); - }, - } - ); - } - - function openProviderDialog(providerId: string) { - resetConnectState(); - setProvider(providerId); - setIsAddProviderOpen(true); - if (providerId === "vertex_ai") { - previewModels.mutate( - { - provider: providerId, - base_url: null, - api_key: null, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: {}, - enabled: true, - models: [], - }, - { - onSuccess: mergePreviewModels, - } - ); - } - } - - function refreshConnectModels(draft: ConnectionDraft) { - previewModels.mutate( - { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models: [], - }, - { - onSuccess: mergePreviewModels, - } - ); - } - - function addConnectModel(modelId: string) { - setConnectModels((current) => { - if (current.some((model) => model.model_id === modelId)) return current; - return [ - ...current, - { - model_id: modelId, - display_name: modelId, - source: "MANUAL", - enabled: true, - metadata: {}, - }, - ]; - }); - } - - function toggleConnectModel(model: SelectableModel, enabled: boolean) { - setConnectModels((current) => - current.map((item) => (item.model_id === model.model_id ? 
{ ...item, enabled } : item)) - ); - } - - function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) { - const modelIds = new Set(models.map((model) => model.model_id)); - setConnectModels((current) => - current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item)) - ); - } - - return ( -
-
- {showAddProviderHeader ? ( -
-

{addProviderTitle}

-

{addProviderDescription}

-
- ) : null} -
- {sortedProviders.map((item) => { - const meta = providerDisplay(item.provider); - - return ( - - ); - })} -
-
- - - - {connections.length > 0 ? ( -
- -

{availableProvidersTitle}

-
- {connections.map((connection) => ( - - ))} -
-
- ) : null} - {footerAction ?
{footerAction}
: null} -
- ); -} diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts deleted file mode 100644 index 2887f2179..000000000 --- a/surfsense_web/components/settings/model-connections/model-utils.ts +++ /dev/null @@ -1,30 +0,0 @@ -import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types"; - -export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; - -export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ - { key: "chat", label: "Chat" }, - { key: "vision", label: "Vision" }, - { key: "image_gen", label: "Image" }, -]; - -export type SelectableModel = (ModelRead | ModelPreviewRead) & { - id?: number | string; - connection_id?: number; -}; - -export function modelLabel(model: SelectableModel) { - return model.display_name || model.model_id; -} - -export function capability(model: SelectableModel, key: ModelCapabilityFilter) { - if (key === "chat") return Boolean(model.supports_chat); - if (key === "vision") return Boolean(model.supports_image_input); - return Boolean(model.supports_image_generation); -} - -export function capabilityLabels(model: SelectableModel) { - return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) - .map((filter) => filter.label.toLowerCase()) - .join(", "); -} diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx deleted file mode 100644 index 3c6990afb..000000000 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ /dev/null @@ -1,198 +0,0 @@ -import { RefreshCw } from "lucide-react"; -import { useState } from "react"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Checkbox } from "@/components/ui/checkbox"; -import { Input } from 
"@/components/ui/input"; -import { Spinner } from "@/components/ui/spinner"; -import { - capability, - capabilityLabels, - MODEL_CAPABILITY_FILTERS, - type ModelCapabilityFilter, - modelLabel, - type SelectableModel, -} from "./model-utils"; - -interface ModelsSelectionPanelProps { - models: SelectableModel[]; - description?: string; - emptyMessage?: string; - manualInputPlaceholder?: string; - refreshLabel?: string; - isRefreshing?: boolean; - isAddingManual?: boolean; - isUpdatingModel?: boolean; - isBulkUpdating?: boolean; - onRefresh?: () => void; - onAddManual?: (modelId: string) => void; - onToggleModel?: (model: SelectableModel, enabled: boolean) => void; - onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void; -} - -export function ModelsSelectionPanel({ - models, - description = "Select models to make available for this provider.", - emptyMessage = "No models available.", - manualInputPlaceholder = "Add a model ID manually", - refreshLabel = "Refresh models", - isRefreshing = false, - isAddingManual = false, - isUpdatingModel = false, - isBulkUpdating = false, - onRefresh, - onAddManual, - onToggleModel, - onBulkToggle, -}: ModelsSelectionPanelProps) { - const [manualModelId, setManualModelId] = useState(""); - const [modelFilter, setModelFilter] = useState(null); - - const filteredModels = modelFilter - ? models.filter((model) => capability(model, modelFilter)) - : models; - const allFilteredModelsEnabled = - filteredModels.length > 0 && filteredModels.every((model) => model.enabled); - - function addModel() { - const modelId = manualModelId.trim(); - if (!modelId || !onAddManual) return; - onAddManual(modelId); - setManualModelId(""); - } - - function toggleFilteredModels() { - const nextEnabled = !allFilteredModelsEnabled; - const changedModels = filteredModels.filter((model) => model.enabled !== nextEnabled); - if (changedModels.length === 0) return; - onBulkToggle?.(changedModels, nextEnabled); - } - - return ( -
-
-
-
Models
-

{description}

-
-
- - {onRefresh ? ( - - ) : null} -
-
- - {onAddManual ? ( -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder={manualInputPlaceholder} - /> - -
- ) : null} - - {models.length > 0 ? ( -
- Filter models - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = models.filter((model) => capability(model, filter.key)).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} - -
- {models.length === 0 ? ( -
- {emptyMessage} -
- ) : null} - {filteredModels.length === 0 && modelFilter ? ( -
- No{" "} - {MODEL_CAPABILITY_FILTERS.find( - (filter) => filter.key === modelFilter - )?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} -
- {filteredModels.map((model) => ( -
- onToggleModel?.(model, checked === true)} - disabled={!onToggleModel || isUpdatingModel} - /> -
-
- {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} -
-
- {capabilityLabels(model) || "No discovered capabilities"} -
-
-
- ))} -
-
-
- ); -} diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx deleted file mode 100644 index 2eee2cf8c..000000000 --- a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx +++ /dev/null @@ -1,155 +0,0 @@ -import { useCallback, useRef, useState } from "react"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; -import { Separator } from "@/components/ui/separator"; -import type { ModelProviderRead } from "@/contracts/types/model-connections.types"; -import { AzureConnectForm } from "./azure-connect-form"; -import { BedrockConnectForm } from "./bedrock-connect-form"; -import { ConnectFormFooter } from "./connect-fields"; -import { DefaultConnectForm } from "./default-connect-form"; -import type { SelectableModel } from "./model-utils"; -import { ModelsSelectionPanel } from "./models-selection-panel"; -import { - type ConnectionDraft, - type ProviderConnectFormProps, - providerDefaultBaseUrl, - providerDisplay, - providerIcon, -} from "./provider-metadata"; -import { VertexConnectForm } from "./vertex-connect-form"; - -interface ProviderConnectDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - provider: string; - selectedProvider?: ModelProviderRead; - isPending: boolean; - onSubmit: (draft: ConnectionDraft) => void; - previewModels?: SelectableModel[]; - isPreviewingModels?: boolean; - onPreviewModels?: (draft: ConnectionDraft) => void; - onAddPreviewModel?: (modelId: string) => void; - onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void; - onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void; -} - -/** - * Shared dialog shell for the "Add Provider" flow. It owns the header and routes - * to the provider-specific connect form. 
Forms remount on open (Radix unmounts - * closed content), so each gets fresh, prefilled state. - */ -export function ProviderConnectDialog({ - open, - onOpenChange, - provider, - selectedProvider, - isPending, - onSubmit, - previewModels = [], - isPreviewingModels = false, - onPreviewModels, - onAddPreviewModel, - onTogglePreviewModel, - onBulkTogglePreviewModels, -}: ProviderConnectDialogProps) { - const meta = providerDisplay(provider); - const isAzure = provider === "azure"; - const isBedrock = provider === "bedrock"; - const isVertex = provider === "vertex_ai"; - const titleRef = useRef(null); - const [currentDraft, setCurrentDraft] = useState({ - base_url: null, - api_key: null, - extra: {}, - }); - const [canSubmit, setCanSubmit] = useState(false); - - const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => { - setCurrentDraft(draft); - setCanSubmit(nextCanSubmit); - }, []); - - const formProps: ProviderConnectFormProps = { - provider, - defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), - baseUrlRequired: Boolean(selectedProvider?.base_url_required), - onDraftChange: handleDraftChange, - }; - - const modelDescription = (() => { - if (isAzure) { - return "Select the models to enable for Azure OpenAI"; - } - if (isBedrock) { - return "Select the models to enable for Amazon Bedrock"; - } - if (isVertex) { - return "Select the models to enable for Gemini"; - } - return "Select the models to enable for this provider"; - })(); - - const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); - const hasEnabledModel = - previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId); - const canConnect = canSubmit && hasEnabledModel; - - return ( - - { - event.preventDefault(); - titleRef.current?.focus(); - }} - > - -
- {providerIcon(provider, "size-5")} -
- - Connect {meta.name} - - {meta.subtitle} -
-
-
-
- {provider === "azure" ? ( - - ) : provider === "bedrock" ? ( - - ) : provider === "vertex_ai" ? ( - - ) : ( - - )} - - - - onPreviewModels?.(currentDraft) : undefined} - onAddManual={onAddPreviewModel} - onToggleModel={onTogglePreviewModel} - onBulkToggle={onBulkTogglePreviewModels} - /> -
- onOpenChange(false)} - onSubmit={() => onSubmit(currentDraft)} - canSubmit={canConnect} - isPending={isPending} - /> -
-
- ); -} diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx deleted file mode 100644 index 8b8a877b9..000000000 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ /dev/null @@ -1,137 +0,0 @@ -import { getProviderIcon } from "@/lib/provider-icons"; - -export const PROVIDER_ORDER = [ - "openai", - "anthropic", - "vertex_ai", - "bedrock", - "azure", - "openrouter", - "ollama_chat", - "lm_studio", - "openai_compatible", -]; - -export const PROVIDER_DISPLAY: Record< - string, - { name: string; subtitle: string; iconKey?: string; defaultBaseUrl?: string } -> = { - anthropic: { - name: "Claude", - subtitle: "Anthropic", - iconKey: "claude", - defaultBaseUrl: "https://api.anthropic.com/v1", - }, - azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure" }, - bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, - lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "lm_studio" }, - ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, - openai: { - name: "GPT", - subtitle: "OpenAI", - iconKey: "openai", - defaultBaseUrl: "https://api.openai.com/v1", - }, - openai_compatible: { - name: "OpenAI-Compatible", - subtitle: "OpenAI-compatible endpoint", - iconKey: "custom", - }, - openrouter: { - name: "OpenRouter", - subtitle: "OpenRouter", - iconKey: "openrouter", - defaultBaseUrl: "https://openrouter.ai/api/v1", - }, - vertex_ai: { name: "Gemini", subtitle: "Google Cloud Vertex AI", iconKey: "vertex_ai" }, -}; - -export function providerDisplay(provider: string) { - const fallback = provider - .split("_") - .filter(Boolean) - .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) - .join(" "); - - return ( - PROVIDER_DISPLAY[provider] ?? 
{ - name: fallback || provider, - subtitle: provider, - iconKey: provider, - } - ); -} - -export function providerIcon(provider: string, className = "size-4") { - return getProviderIcon(providerDisplay(provider).iconKey ?? provider, { className }); -} - -export function providerDefaultBaseUrl(provider: string, registryDefault?: string | null) { - return registryDefault ?? PROVIDER_DISPLAY[provider]?.defaultBaseUrl ?? ""; -} - -export const AWS_REGION_OPTIONS = [ - "us-east-1", - "us-east-2", - "us-west-2", - "us-gov-east-1", - "us-gov-west-1", - "ap-northeast-1", - "ap-south-1", - "ap-southeast-1", - "ap-southeast-2", - "ap-east-1", - "ca-central-1", - "eu-central-1", - "eu-west-2", -]; - -export const VERTEX_DEFAULT_LOCATION = "global"; - -export const BEDROCK_AUTH_IAM = "iam"; -export const BEDROCK_AUTH_ACCESS_KEY = "access_key"; -export const BEDROCK_AUTH_LONG_TERM_API_KEY = "long_term_api_key"; - -export const VERTEX_AUTH_SERVICE_ACCOUNT = "service_account_json"; -export const VERTEX_AUTH_WORKLOAD_IDENTITY = "workload_identity"; - -// Mirrors Onyx's Azure "Target URI" parser: the user pastes the full endpoint -// (e.g. https://res.cognitiveservices.azure.com/openai/deployments//chat/completions?api-version=) -// which we split into api base (origin), api version, and deployment name. -export function parseAzureTargetUri(rawUri: string) { - try { - const url = new URL(rawUri); - const deploymentMatch = url.pathname.match(/\/openai\/deployments\/([^/]+)/i); - return { - origin: url.origin, - apiVersion: url.searchParams.get("api-version")?.trim() ?? "", - deploymentName: deploymentMatch?.[1] ? 
deploymentMatch[1].toLowerCase() : "", - isResponsesPath: /\/openai\/responses/i.test(url.pathname), - }; - } catch { - return null; - } -} - -export function isValidAzureTargetUri(rawUri: string) { - const parsed = parseAzureTargetUri(rawUri); - if (!parsed) return false; - return Boolean(parsed.apiVersion) && (Boolean(parsed.deploymentName) || parsed.isResponsesPath); -} - -/** Connection payload produced by a provider connect form. */ -export interface ConnectionDraft { - base_url: string | null; - api_key: string | null; - extra: Record; - /** Model id to seed after creation (providers without discovery, e.g. Azure). */ - seedModelId?: string; -} - -/** Props shared by every provider-specific connect form. */ -export interface ProviderConnectFormProps { - provider: string; - defaultBaseUrl: string; - baseUrlRequired: boolean; - onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void; -} diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx deleted file mode 100644 index 1027742bc..000000000 --- a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx +++ /dev/null @@ -1,118 +0,0 @@ -import { useEffect, useState } from "react"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { - type ProviderConnectFormProps, - VERTEX_AUTH_SERVICE_ACCOUNT, - VERTEX_AUTH_WORKLOAD_IDENTITY, - VERTEX_DEFAULT_LOCATION, -} from "./provider-metadata"; - -/** - * Google Vertex AI (Gemini) connect form. Service-account auth uploads a - * credentials JSON file (read into a string); workload identity collects a - * project id. Credentials ride along in `extra.litellm_params`. 
- */ -export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) { - const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); - const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); - const [credentials, setCredentials] = useState(""); - const [project, setProject] = useState(""); - - const canSubmit = - authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? Boolean(credentials) : Boolean(project); - - async function handleCredentialsFile(file: File | undefined) { - if (!file) return; - setCredentials(await file.text()); - } - - useEffect(() => { - const params: Record = {}; - if (location) params.vertex_location = location; - if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { - if (credentials) params.vertex_credentials = credentials; - } else if (project) { - params.vertex_project = project; - } - onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); - }, [authMethod, canSubmit, credentials, location, onDraftChange, project]); - - return ( -
-
- - -
-
- - setLocation(event.target.value)} - placeholder={VERTEX_DEFAULT_LOCATION} - /> -

- Region where your Google Vertex AI models are hosted. -

-
- {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( -
- - handleCredentialsFile(event.target.files?.[0])} - /> - -

- {credentials - ? "Credentials file loaded." - : "Attach your service account key JSON from Google Cloud."} -

-
- ) : ( -
- - setProject(event.target.value)} - placeholder="my-vertex-project" - /> -

- The GCP project where Vertex AI is enabled. -

-
- )} -

- Add Vertex AI model IDs from the provider's settings after connecting. -

-
- ); -} diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx new file mode 100644 index 000000000..31578b4f1 --- /dev/null +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -0,0 +1,486 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; +import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; +import { + globalVisionLLMConfigsAtom, + visionLLMConfigsAtom, +} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; +import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import type { VisionLLMConfig } from "@/contracts/types/new-llm-config.types"; +import { useMediaQuery } from "@/hooks/use-media-query"; +import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; + +interface VisionModelManagerProps { + searchSpaceId: number; +} + +function getInitials(name: string): string 
{ + const parts = name.trim().split(/\s+/); + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return name.slice(0, 2).toUpperCase(); +} + +/** Manages the vision model configurations of one search space: lists the global models that are not auto-mode entries, lists the configs stored for this space, and exposes create/edit/delete actions according to the access data of the current user. */export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { + const isDesktop = useMediaQuery("(min-width: 768px)"); + + const { + mutateAsync: deleteConfig, + isPending: isDeleting, + error: deleteError, + } = useAtomValue(deleteVisionLLMConfigMutationAtom); + + const { + data: userConfigs, + isFetching: configsLoading, + error: fetchError, + refetch: refreshConfigs, + } = useAtomValue(visionLLMConfigsAtom); + const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue( + globalVisionLLMConfigsAtom + ); + + const { data: members } = useAtomValue(membersAtom); +/* Lookup from user_id to display name, email and avatar URL; read below to show who created each config. */ const memberMap = useMemo(() => { + const map = new Map(); + if (members) { + for (const m of members) { + map.set(m.user_id, { + name: m.user_display_name || m.user_email || "Unknown", + email: m.user_email || undefined, + avatarUrl: m.user_avatar_url || undefined, + }); + } + } + return map; + }, [members]); + + const { data: access } = useAtomValue(myAccessAtom); +/* Creation is allowed for owners or for members holding the vision_configs:create permission; false while access data is missing. */ const canCreate = useMemo(() => { + if (!access) return false; + if (access.is_owner) return true; + return access.permissions?.includes("vision_configs:create") ?? false; + }, [access]); +/* Same rule as canCreate, but for the vision_configs:delete permission. */ const canDelete = useMemo(() => { + if (!access) return false; + if (access.is_owner) return true; + return access.permissions?.includes("vision_configs:delete") ?? 
false; + }, [access]); +/* Editing reuses the create check; no separate update permission is consulted here. */ const canUpdate = canCreate; +/* Read-only means the user can neither create nor delete. */ const isReadOnly = !canCreate && !canDelete; + + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [editingConfig, setEditingConfig] = useState(null); + const [configToDelete, setConfigToDelete] = useState(null); + +/* Loading while either the space configs or the global configs are being fetched. */ const isLoading = configsLoading || globalLoading; +/* Keep only the delete/fetch errors that are actually set. */ const errors = [deleteError, fetchError].filter(Boolean) as Error[]; + +/* Open the dialog on an existing config; a non-null editingConfig puts the dialog in edit mode. */ const openEditDialog = (config: VisionLLMConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); + }; + +/* Open the dialog with no config selected, i.e. create mode. */ const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; + +/* Delete the config awaiting confirmation and clear the selection on success. A rejected delete is swallowed and the selection is kept. */ const handleDelete = async () => { + if (!configToDelete) return; + try { + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); + setConfigToDelete(null); + } catch { + // Error handled by mutation + } + }; + + return ( +
+
+ + {canCreate && ( + + )} +
+ + {errors.map((err) => ( +
+ + + {err?.message} + +
+ ))} + + {access && !isLoading && isReadOnly && ( +
+ + + +

+ You have read-only access to vision model + configurations. Contact a space owner to request additional permissions. +

+
+
+
+ )} + {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( +
+ + + +

+ You can{" "} + {[canCreate && "create and edit", canDelete && "delete"] + .filter(Boolean) + .join(" and ")}{" "} + vision model configurations + {!canDelete && ", but cannot delete them"}. +

+
+
+
+ )} + + {(isLoading || + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( + + + + {isLoading ? ( +
+ +
+ ) : ( +

+ + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} + global vision{" "} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === + 1 + ? "model" + : "models"} + {" "} + available from your administrator. {(() => { + const nonAuto = globalConfigs.filter( + (g) => !("is_auto_mode" in g && g.is_auto_mode) + ); + const premium = nonAuto.filter( + (g) => + "billing_tier" in g && + (g as { billing_tier?: string }).billing_tier === "premium" + ).length; + const free = nonAuto.length - premium; + if (premium > 0 && free > 0) { + return `${premium} premium, ${free} free.`; + } + if (premium > 0) { + return `All ${premium} premium — debits your shared credit pool.`; + } + return `All ${free} free.`; + })()} +

+ )} +
+
+ )} + + {/* Global Vision Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +
+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ +
+ + {cfg.model_name} + +
+
+
+
+ ); + })} +
+
+ )} + + {isLoading && ( +
+
+
+ {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( + + + + + + + + ))} +
+
+
+ )} + + {!isLoading && ( +
+ {(userConfigs?.length ?? 0) === 0 ? ( + + +

No Vision Models Yet

+

+ {canCreate + ? "Add your own vision-capable model (GPT-4o, Claude, Gemini, etc.)" + : "No vision models have been added to this space yet. Contact a space owner to add one."} +

+
+
+ ) : ( +
+ {userConfigs?.map((config) => { + const member = config.user_id ? memberMap.get(config.user_id) : null; + + return ( +
+ + + {/* Header: Icon + Name + Actions */} +
+
+
+ {getProviderIcon(config.provider, { className: "size-4" })} +
+
+

+ {config.name} +

+ {config.description && ( +

+ {config.description} +

+ )} +
+
+ {(canUpdate || canDelete) && ( +
+ {canUpdate && ( + + + + + + Edit + + + )} + {canDelete && ( + + + + + + Delete + + + )} +
+ )} +
+ + {/* Footer: Date + Creator */} +
+ +
+ + {new Date(config.created_at).toLocaleDateString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + })} + + {member && ( + <> + + + + +
+ + {member.avatarUrl && ( + + )} + + {getInitials(member.name)} + + + + {member.name} + +
+
+ + {member.email || member.name} + +
+
+ + )} +
+
+
+
+
+ ); + })} +
+ )} +
+ )} + + { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> + + !open && setConfigToDelete(null)} + > + + + Delete Vision Model + + Are you sure you want to delete{" "} + {configToDelete?.name}? + + + + Cancel + + Delete + {isDeleting && } + + + + +
+ ); +} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx new file mode 100644 index 000000000..36d16081a --- /dev/null +++ b/surfsense_web/components/shared/image-config-dialog.tsx @@ -0,0 +1,456 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { + createImageGenConfigMutationAtom, + updateImageGenConfigMutationAtom, +} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; +import type { + GlobalImageGenConfig, + ImageGenerationConfig, + ImageGenProvider, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface ImageConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: ImageGenerationConfig | GlobalImageGenConfig | null; + isGlobal: boolean; + 
searchSpaceId: number; + mode: "create" | "edit" | "view"; + defaultProvider?: string; +} + +/* Empty form values; the starting state of the form and the base for create mode. */const INITIAL_FORM = { + name: "", + description: "", + provider: "", + model_name: "", + api_key: "", + api_base: "", + api_version: "", +}; + +/** Dialog for one image generation config. Create mode and edit mode on a non-global config render an editable form; a global config renders read-only details. defaultProvider preselects the provider in create mode. */export function ImageConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, + defaultProvider, +}: ImageConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [formData, setFormData] = useState(INITIAL_FORM); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef(null); + +/* On open, seed the form from the config (edit of a non-global config) or from INITIAL_FORM plus defaultProvider (create), then reset the scroll position marker to top. */ useEffect(() => { + if (open) { + if (mode === "edit" && config && !isGlobal) { + setFormData({ + name: config.name || "", + description: config.description || "", + provider: config.provider || "", + model_name: config.model_name || "", + api_key: (config as ImageGenerationConfig).api_key || "", + api_base: config.api_base || "", + api_version: config.api_version || "", + }); + } else if (mode === "create") { + setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); + } + setScrollPos("top"); + } + }, [open, mode, config, isGlobal, defaultProvider]); + + const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + +/* Classify the scroll position as top, bottom or middle, with a tolerance of 2 on either end. */ const handleScroll = useCallback((e: React.UIEvent) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? 
"bottom" : "middle"); + }, []); + +/* Entries of IMAGE_GEN_MODELS for the selected provider; empty when no provider is set. */ const suggestedModels = useMemo(() => { + if (!formData.provider) return []; + return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); + }, [formData.provider]); + + const getTitle = () => { + if (mode === "create") return "Add Image Model"; + if (isGlobal) return "View Global Image Model"; + return "Edit Image Model"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new image generation provider"; + if (isGlobal) return "Read-only global configuration"; + return "Update your image model settings"; + }; + +/* Create or update the config, then close the dialog. When a create returns an id, that config is also stored as the image_generation_config_id preference of the search space. Failures are logged and reported through a toast; isSubmitting is reset in every case. */ const handleSubmit = useCallback(async () => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + name: formData.name, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + description: formData.description || undefined, + search_space_id: searchSpaceId, + }); + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: result.id }, + }); + } + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: formData.name, + description: formData.description || undefined, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save image config:", error); + toast.error("Failed to save image model"); + } finally { + setIsSubmitting(false); + } + }, [ + mode, + isGlobal, + config, + formData, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ]); + +/* Store the current global config as the image_generation_config_id preference of the search space, then close the dialog. No-op unless a global config is present. */ const handleUseGlobalConfig = useCallback(async () => 
{ + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: config.id }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set image model:", error); + toast.error("Failed to set image model"); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + +/* Truthy only when name, provider, model name and API key are all non-empty. */ const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; + const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); + + return ( + + e.preventDefault()} + > + {getTitle()} + + {/* Header */} +
+
+
+

{getTitle()}

+ {isGlobal && mode !== "create" && ( + + Global + + )} +
+

{getSubtitle()}

+ {config && mode !== "create" && ( +

{config.model_name}

+ )} +
+
+ + {/* Scrollable content */} +
+ {isGlobal && config && ( + <> + + + + Global configurations are read-only. To customize, create a new model. + + +
+
+
+
+ Name +
+

{config.name}

+
+ {config.description && ( +
+
+ Description +
+

{config.description}

+
+ )} +
+ +
+
+
+ Provider +
+

{config.provider}

+
+
+
+ Model +
+

{config.model_name}

+
+
+
+ + )} + + {(mode === "create" || (mode === "edit" && !isGlobal)) && ( +
+
+ + setFormData((p) => ({ ...p, name: e.target.value }))} + /> +
+ +
+ + setFormData((p) => ({ ...p, description: e.target.value }))} + /> +
+ + + +
+ + +
+ +
+ + {suggestedModels.length > 0 ? ( + + + + + + + setFormData((p) => ({ ...p, model_name: val }))} + /> + + + + Type a custom model name + + + + {suggestedModels.map((m) => ( + { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + + {m.value} + + {m.label} + + + ))} + + + + + + ) : ( + setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} +
+ +
+ + setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> +
+ +
+ + setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> +
+ + {formData.provider === "AZURE_OPENAI" && ( +
+ + setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> +
+ )} +
+ )} +
+ + {/* Fixed footer */} +
+ + {mode === "create" || (mode === "edit" && !isGlobal) ? ( + + ) : isGlobal && config ? ( + + ) : null} +
+
+
+ ); +} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx new file mode 100644 index 000000000..06de4129b --- /dev/null +++ b/surfsense_web/components/shared/llm-config-form.tsx @@ -0,0 +1,527 @@ +"use client"; + +import { zodResolver } from "@hookform/resolvers/zod"; +import { useAtomValue } from "jotai"; +import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { type Resolver, useForm } from "react-hook-form"; +import { z } from "zod"; +import { + defaultSystemInstructionsAtom, + modelListAtom, +} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Switch } from "@/components/ui/switch"; +import { Textarea } from "@/components/ui/textarea"; +import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; +import type { CreateNewLLMConfigRequest } from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; +import InferenceParamsEditor from "../inference-params-editor"; + +// Form schema with zod +const formSchema = z.object({ + name: z.string().min(1, "Name is required").max(100), + description: 
z.string().max(500).optional().nullable(), + provider: z.string().min(1, "Provider is required"), + custom_provider: z.string().max(100).optional().nullable(), + model_name: z.string().min(1, "Model name is required").max(100), + api_key: z.string().min(1, "API key is required"), + api_base: z.string().max(500).optional().nullable(), + litellm_params: z.record(z.string(), z.any()).optional().nullable(), + system_instructions: z.string().default(""), + use_default_system_instructions: z.boolean().default(true), + citations_enabled: z.boolean().default(true), + search_space_id: z.number(), +}); + +type FormValues = z.infer; + +export type LLMConfigFormData = CreateNewLLMConfigRequest; + +interface LLMConfigFormProps { + initialData?: Partial; + searchSpaceId: number; + onSubmit: (data: LLMConfigFormData) => Promise; + mode?: "create" | "edit"; + showAdvanced?: boolean; + formId?: string; +} + +export function LLMConfigForm({ + initialData, + searchSpaceId, + onSubmit, + mode = "create", + showAdvanced = true, + formId, +}: LLMConfigFormProps) { + const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( + defaultSystemInstructionsAtom + ); + const { data: dynamicModels } = useAtomValue(modelListAtom); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const [advancedOpen, setAdvancedOpen] = useState(false); + const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); + + const form = useForm({ + resolver: zodResolver(formSchema) as Resolver, + defaultValues: { + name: initialData?.name ?? "", + description: initialData?.description ?? "", + provider: initialData?.provider ?? "", + custom_provider: initialData?.custom_provider ?? "", + model_name: initialData?.model_name ?? "", + api_key: initialData?.api_key ?? "", + api_base: initialData?.api_base ?? "", + litellm_params: initialData?.litellm_params ?? {}, + system_instructions: initialData?.system_instructions ?? 
"", + use_default_system_instructions: initialData?.use_default_system_instructions ?? true, + citations_enabled: initialData?.citations_enabled ?? true, + search_space_id: searchSpaceId, + }, + }); + + // Load default instructions when available (only for new configs) + useEffect(() => { + if ( + mode === "create" && + defaultInstructionsLoaded && + defaultInstructions?.default_system_instructions && + !form.getValues("system_instructions") + ) { + form.setValue("system_instructions", defaultInstructions.default_system_instructions); + } + }, [defaultInstructionsLoaded, defaultInstructions, mode, form]); + + const watchProvider = form.watch("provider"); + const selectedProvider = LLM_PROVIDERS.find((p) => p.value === watchProvider); + const availableModels = useMemo( + () => (dynamicModels ?? []).filter((m) => m.provider === watchProvider), + [dynamicModels, watchProvider] + ); + + const handleProviderChange = (value: string) => { + form.setValue("provider", value); + form.setValue("model_name", ""); + + // Auto-fill API base for certain providers + const provider = LLM_PROVIDERS.find((p) => p.value === value); + if (provider?.apiBase) { + form.setValue("api_base", provider.apiBase); + } + }; + + const handleFormSubmit = async (values: FormValues) => { + await onSubmit(values as LLMConfigFormData); + }; + + return ( +
+ + {/* Model Configuration Section */} +
+
+ Model Configuration +
+ + {/* Name & Description */} +
+ ( + + Configuration Name + + + + + + )} + /> + + ( + + + Description + + Optional + + + + + + + + )} + /> +
+ + {/* Provider Selection */} + ( + + LLM Provider + + + + )} + /> + + {/* Custom Provider (conditional) */} + {watchProvider === "CUSTOM" && ( + ( + + Custom Provider Name + + + + + + )} + /> + )} + + {/* Model Name with Combobox */} + ( + + Model Name + + + + + + + + + + + +
+ {field.value ? `Using: "${field.value}"` : "Type your model name"} +
+
+ {availableModels.length > 0 && ( + + {availableModels + .filter( + (model) => + !field.value || + model.value.toLowerCase().includes(field.value.toLowerCase()) || + model.label.toLowerCase().includes(field.value.toLowerCase()) + ) + .slice(0, 50) + .map((model) => ( + { + field.onChange(value); + setModelComboboxOpen(false); + }} + className="py-2" + > + +
+
{model.label}
+ {model.contextWindow && ( +
+ Context: {model.contextWindow} +
+ )} +
+
+ ))} +
+ )} +
+
+
+
+ {selectedProvider?.example && ( + + Example: {selectedProvider.example} + + )} + +
+ )} + /> + + {/* API Credentials */} +
+ ( + + API Key + + + + {watchProvider === "OLLAMA" && ( + + Ollama doesn't require auth — enter any value + + )} + + + )} + /> + + ( + + + API Base URL + {selectedProvider?.apiBase && ( + + Auto-filled + + )} + + + + + + + )} + /> +
+ + {/* Ollama Quick Actions */} + {watchProvider === "OLLAMA" && ( +
+ + +
+ )} +
+ + {/* Advanced Parameters */} + {showAdvanced && ( + <> + + + + + + + ( + + + + + + + )} + /> + + + + )} + + {/* System Instructions & Citations Section */} + + + + + + + {/* System Instructions */} + ( + +
+ Instructions for the AI + {defaultInstructions && ( + + )} +
+ +