mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-18 21:15:16 +02:00
Merge pull request #1509 from MODSetter/dev
feat(release: 0.0.29): ETL/embedding caches, unified model connections, reverse-proxy support, podcast & indexing improvements
This commit is contained in:
commit
c941907448
408 changed files with 15877 additions and 16310 deletions
7
.github/workflows/desktop-release.yml
vendored
7
.github/workflows/desktop-release.yml
vendored
|
|
@ -95,10 +95,12 @@ jobs:
|
|||
run: pnpm build
|
||||
working-directory: surfsense_web
|
||||
env:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}
|
||||
NEXT_PUBLIC_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_AUTH_TYPE }}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${{ vars.NEXT_PUBLIC_ETL_SERVICE }}
|
||||
NEXT_PUBLIC_POSTHOG_KEY: ${{ secrets.NEXT_PUBLIC_POSTHOG_KEY }}
|
||||
|
||||
- name: Install desktop dependencies
|
||||
|
|
@ -109,6 +111,7 @@ jobs:
|
|||
run: pnpm build
|
||||
working-directory: surfsense_desktop
|
||||
env:
|
||||
HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }}
|
||||
POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }}
|
||||
POSTHOG_HOST: ${{ vars.POSTHOG_HOST }}
|
||||
|
|
|
|||
5
.github/workflows/docker-build.yml
vendored
5
.github/workflows/docker-build.yml
vendored
|
|
@ -199,11 +199,6 @@ jobs:
|
|||
build-args: |
|
||||
${{ matrix.image == 'backend' && format('USE_CUDA={0}', matrix.use_cuda) || '' }}
|
||||
${{ matrix.image == 'backend' && format('CUDA_EXTRA={0}', matrix.cuda_extra) || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
|
|
|
|||
5
.github/workflows/e2e-tests.yml
vendored
5
.github/workflows/e2e-tests.yml
vendored
|
|
@ -27,9 +27,10 @@ jobs:
|
|||
PLAYWRIGHT_TEST_EMAIL: e2e-test@surfsense.net
|
||||
PLAYWRIGHT_TEST_PASSWORD: E2eTestPassword123!
|
||||
# Frontend env: Playwright's webServer (surfsense_web/playwright.config.ts)
|
||||
# spawns `pnpm build && pnpm start` in CI; these get baked into the build.
|
||||
# spawns `pnpm build && pnpm start` in CI.
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: http://localhost:8000
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: LOCAL
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://localhost:8000
|
||||
AUTH_TYPE: LOCAL
|
||||
# Shared secret for the test-only POST /__e2e__/auth/token endpoint.
|
||||
# Must match docker-compose.e2e.yml's backend env (x-backend-env).
|
||||
E2E_MINT_SECRET: e2e-mint-secret-not-for-production
|
||||
|
|
|
|||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.0.28
|
||||
0.0.29
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ SECRET_KEY=replace_me_with_a_random_string
|
|||
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
|
||||
AUTH_TYPE=LOCAL
|
||||
|
||||
# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them.
|
||||
DEPLOYMENT_MODE=self-hosted
|
||||
|
||||
# Allow new user registrations (TRUE or FALSE)
|
||||
# REGISTRATION_ENABLED=TRUE
|
||||
|
||||
|
|
@ -43,51 +46,47 @@ ETL_SERVICE=DOCLING
|
|||
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Ports (change to avoid conflicts with other services on your machine)
|
||||
# How You Access SurfSense
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# BACKEND_PORT=8929
|
||||
# FRONTEND_PORT=3929
|
||||
# ZERO_CACHE_PORT=5929
|
||||
# SEARXNG_PORT=8888
|
||||
# FLOWER_PORT=5555
|
||||
|
||||
# ==============================================================================
|
||||
# DEV COMPOSE ONLY (docker-compose.dev.yml)
|
||||
# You only need them only if you are running `docker-compose.dev.yml`.
|
||||
# ==============================================================================
|
||||
|
||||
# -- pgAdmin (database GUI) --
|
||||
# PGADMIN_PORT=5050
|
||||
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
# PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
|
||||
# -- Redis exposed port (dev only; Redis is internal-only in prod) --
|
||||
# REDIS_PORT=6379
|
||||
|
||||
# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) --
|
||||
# WHATSAPP_BRIDGE_PORT=9929
|
||||
|
||||
# -- Frontend Build Args --
|
||||
# In dev, the frontend is built from source and these are passed as build args.
|
||||
# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above.
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
# NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
|
||||
# One public URL. Browser traffic stays same-origin and Caddy routes internally.
|
||||
SURFSENSE_PUBLIC_URL=http://localhost:3929
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Custom Domain / Reverse Proxy
|
||||
# Public Ports
|
||||
# ------------------------------------------------------------------------------
|
||||
# ONLY set these if you are serving SurfSense on a real domain via a reverse
|
||||
# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel).
|
||||
# For standard localhost deployments, leave all of these commented out.
|
||||
# they are automatically derived from the port settings above.
|
||||
# Production Docker exposes only Caddy to your machine. Caddy then routes
|
||||
# frontend, backend, and zero-cache traffic internally.
|
||||
#
|
||||
# NEXT_FRONTEND_URL=https://app.yourdomain.com
|
||||
# BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com
|
||||
# FASTAPI_BACKEND_INTERNAL_URL=http://backend:8000
|
||||
# Local default: LISTEN_HTTP_PORT=3929
|
||||
# Domain default: LISTEN_HTTP_PORT=80 and LISTEN_HTTPS_PORT=443
|
||||
LISTEN_HTTP_PORT=3929
|
||||
LISTEN_HTTPS_PORT=443
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Custom Domain / HTTPS
|
||||
# ------------------------------------------------------------------------------
|
||||
# Leave SURFSENSE_SITE_ADDRESS as :80 for local HTTP.
|
||||
# Set it to your domain to enable automatic HTTPS:
|
||||
# SURFSENSE_SITE_ADDRESS=surf.example.com
|
||||
# CERT_EMAIL=you@example.com
|
||||
SURFSENSE_SITE_ADDRESS=:80
|
||||
CERT_EMAIL=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Advanced Reverse Proxy Settings
|
||||
# ------------------------------------------------------------------------------
|
||||
# Usually do not change these. They are for custom certificate setup, CDNs/load
|
||||
# balancers, trusted proxy IPs, or changing upload limits.
|
||||
#
|
||||
# CERT_ACME_CA=https://acme-v02.api.letsencrypt.org/directory
|
||||
# CERT_ACME_DNS=
|
||||
# If a CDN/load balancer sits in front of Caddy, narrow this to that proxy's CIDRs; the default 0.0.0.0/0 trusts forwarded client-IP headers from any IPv4 source.
|
||||
# TRUSTED_PROXIES=0.0.0.0/0
|
||||
# SURFSENSE_MAX_BODY_SIZE=5GB
|
||||
#
|
||||
# Browser API and Zero URLs are same-origin relative behind bundled Caddy.
|
||||
# Next.js server-side calls use Docker DNS through SURFSENSE_BACKEND_INTERNAL_URL
|
||||
# set internally by docker-compose.yml. Usually do not override it.
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Zero-cache (real-time sync)
|
||||
|
|
@ -108,10 +107,9 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
|
||||
# Sync worker tuning. zero-cache defaults ZERO_NUM_SYNC_WORKERS to the number
|
||||
# of CPU cores, which can exceed the connection pool limits on high-core machines.
|
||||
# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR
|
||||
# pools, so these constraints must hold:
|
||||
# ZERO_UPSTREAM_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS
|
||||
# ZERO_CVR_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS
|
||||
# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR pools.
|
||||
# Keep ZERO_UPSTREAM_MAX_CONNS and ZERO_CVR_MAX_CONNS greater than or equal to
|
||||
# ZERO_NUM_SYNC_WORKERS.
|
||||
# Default of 4 workers is sufficient for self-hosted / personal use.
|
||||
# ZERO_NUM_SYNC_WORKERS=4
|
||||
# ZERO_UPSTREAM_MAX_CONNS=20
|
||||
|
|
@ -125,16 +123,16 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
|
||||
# ZERO_QUERY_URL: where zero-cache forwards query requests for resolution.
|
||||
# ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though
|
||||
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
|
||||
# skip its own JWT verification and let the app endpoints handle auth instead.
|
||||
# The mutate endpoint is a no-op that returns an empty response.
|
||||
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
|
||||
# skip its own JWT verification and let the app endpoints handle auth instead.
|
||||
# The mutate endpoint is a no-op that returns an empty response.
|
||||
# Default: Docker service networking (http://frontend:3000/api/zero/...).
|
||||
# Override when running the frontend outside Docker:
|
||||
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
|
||||
# Override for custom domain:
|
||||
# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query
|
||||
# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
|
||||
# Override for custom domain only when zero-cache is not in the bundled Docker network:
|
||||
# ZERO_QUERY_URL=https://surf.example.com/api/zero/query
|
||||
# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
|
||||
|
||||
|
|
@ -222,73 +220,74 @@ STT_SERVICE=local/base
|
|||
# ------------------------------------------------------------------------------
|
||||
|
||||
# -- Google Connectors --
|
||||
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
|
||||
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
|
||||
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
|
||||
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/calendar/connector/callback
|
||||
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/gmail/connector/callback
|
||||
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/drive/connector/callback
|
||||
|
||||
# -- Notion --
|
||||
# NOTION_CLIENT_ID=
|
||||
# NOTION_CLIENT_SECRET=
|
||||
# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
|
||||
# NOTION_REDIRECT_URI=http://localhost:3929/api/v1/auth/notion/connector/callback
|
||||
|
||||
# -- Slack --
|
||||
# SLACK_CLIENT_ID=
|
||||
# SLACK_CLIENT_SECRET=
|
||||
# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback
|
||||
# SLACK_REDIRECT_URI=http://localhost:3929/api/v1/auth/slack/connector/callback
|
||||
|
||||
# -- Discord --
|
||||
# DISCORD_CLIENT_ID=
|
||||
# DISCORD_CLIENT_SECRET=
|
||||
# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback
|
||||
# DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/auth/discord/connector/callback
|
||||
# DISCORD_BOT_TOKEN=
|
||||
|
||||
# -- Atlassian (Jira & Confluence) --
|
||||
# ATLASSIAN_CLIENT_ID=
|
||||
# ATLASSIAN_CLIENT_SECRET=
|
||||
# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback
|
||||
# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback
|
||||
# JIRA_REDIRECT_URI=http://localhost:3929/api/v1/auth/jira/connector/callback
|
||||
# CONFLUENCE_REDIRECT_URI=http://localhost:3929/api/v1/auth/confluence/connector/callback
|
||||
|
||||
# -- Linear --
|
||||
# LINEAR_CLIENT_ID=
|
||||
# LINEAR_CLIENT_SECRET=
|
||||
# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
|
||||
# LINEAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/linear/connector/callback
|
||||
|
||||
# -- ClickUp --
|
||||
# CLICKUP_CLIENT_ID=
|
||||
# CLICKUP_CLIENT_SECRET=
|
||||
# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback
|
||||
# CLICKUP_REDIRECT_URI=http://localhost:3929/api/v1/auth/clickup/connector/callback
|
||||
|
||||
# -- Airtable --
|
||||
# AIRTABLE_CLIENT_ID=
|
||||
# AIRTABLE_CLIENT_SECRET=
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:3929/api/v1/auth/airtable/connector/callback
|
||||
|
||||
# -- Microsoft OAuth (Teams & OneDrive) --
|
||||
# MICROSOFT_CLIENT_ID=
|
||||
# MICROSOFT_CLIENT_SECRET=
|
||||
# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
# ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback
|
||||
# TEAMS_REDIRECT_URI=http://localhost:3929/api/v1/auth/teams/connector/callback
|
||||
# ONEDRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/onedrive/connector/callback
|
||||
|
||||
# -- Dropbox --
|
||||
# DROPBOX_APP_KEY=
|
||||
# DROPBOX_APP_SECRET=
|
||||
# DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback
|
||||
# DROPBOX_REDIRECT_URI=http://localhost:3929/api/v1/auth/dropbox/connector/callback
|
||||
|
||||
# -- Composio --
|
||||
# COMPOSIO_API_KEY=
|
||||
# COMPOSIO_ENABLED=TRUE
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:3929/api/v1/auth/composio/connector/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Messaging Channels (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Configure only the external chat channels you want to use.
|
||||
# GATEWAY_ENABLED=TRUE
|
||||
|
||||
# -- Telegram --
|
||||
# TELEGRAM_SHARED_BOT_TOKEN=
|
||||
# TELEGRAM_SHARED_BOT_USERNAME=
|
||||
# TELEGRAM_WEBHOOK_SECRET=
|
||||
# GATEWAY_BASE_URL=http://localhost:8929
|
||||
# GATEWAY_BASE_URL=http://localhost:3929
|
||||
# GATEWAY_TELEGRAM_INTAKE_MODE=webhook
|
||||
|
||||
# -- WhatsApp --
|
||||
|
|
@ -307,20 +306,20 @@ STT_SERVICE=local/base
|
|||
#
|
||||
# GATEWAY_SLACK_ENABLED=FALSE
|
||||
# GATEWAY_SLACK_SIGNING_SECRET=
|
||||
# GATEWAY_SLACK_REDIRECT_URI=http://localhost:8929/api/v1/gateway/slack/callback
|
||||
# GATEWAY_SLACK_REDIRECT_URI=http://localhost:3929/api/v1/gateway/slack/callback
|
||||
|
||||
# -- Discord --
|
||||
# Uses DISCORD_CLIENT_ID, DISCORD_CLIENT_SECRET, and DISCORD_BOT_TOKEN from the
|
||||
# Discord connector section.
|
||||
#
|
||||
# GATEWAY_DISCORD_ENABLED=FALSE
|
||||
# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:8929/api/v1/gateway/discord/callback
|
||||
# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/gateway/discord/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG (bundled web search, works out of the box with no config needed)
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG provides web search to all search spaces automatically.
|
||||
# To access the SearXNG UI directly: http://localhost:8888
|
||||
# To access the SearXNG UI directly in dev/deps-only compose: http://localhost:8888
|
||||
# To disable the service entirely: docker compose up --scale searxng=0
|
||||
# To point at your own SearXNG instance instead of the bundled one:
|
||||
# SEARXNG_DEFAULT_HOST=http://your-searxng:8080
|
||||
|
|
@ -457,3 +456,36 @@ NOLOGIN_MODE_ENABLED=FALSE
|
|||
# RESIDENTIAL_PROXY_HOSTNAME=
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
|
||||
# ==============================================================================
|
||||
# DEV / DEPS-ONLY COMPOSE OVERRIDES
|
||||
# These are only needed for docker-compose.dev.yml or docker-compose.deps-only.yml.
|
||||
# Production Docker exposes Caddy only; raw app ports below do not affect
|
||||
# docker-compose.yml.
|
||||
# ==============================================================================
|
||||
|
||||
# -- pgAdmin (database GUI; dev and deps-only compose only) --
|
||||
# PGADMIN_PORT=5050
|
||||
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
# PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
|
||||
# -- Redis exposed port (dev and deps-only compose only; Redis is internal-only in prod) --
|
||||
# REDIS_PORT=6379
|
||||
|
||||
# -- SearXNG exposed port (dev and deps-only compose only; internal-only in prod) --
|
||||
# SEARXNG_PORT=8888
|
||||
|
||||
# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) --
|
||||
# WHATSAPP_BRIDGE_PORT=9929
|
||||
|
||||
# -- Raw app ports (dev and deps-only compose only; prod exposes Caddy instead) --
|
||||
# BACKEND_PORT=8000
|
||||
# FRONTEND_PORT=3000
|
||||
# ZERO_CACHE_PORT=4848
|
||||
|
||||
# -- Frontend runtime flags (prod and dev compose) --
|
||||
# The frontend reads these at request time in Docker; no NEXT_PUBLIC_* rebuild
|
||||
# or startup substitution is required.
|
||||
# AUTH_TYPE=LOCAL
|
||||
# ETL_SERVICE=DOCLING
|
||||
# DEPLOYMENT_MODE=self-hosted
|
||||
|
|
|
|||
|
|
@ -257,16 +257,15 @@ services:
|
|||
frontend:
|
||||
build:
|
||||
context: ../surfsense_web
|
||||
args:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
env_file:
|
||||
- ../surfsense_web/.env
|
||||
environment:
|
||||
AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
|
|
|
|||
54
docker/docker-compose.proxy.yml
Normal file
54
docker/docker-compose.proxy.yml
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# =============================================================================
|
||||
# SurfSense — Optional Caddy reverse-proxy overlay
|
||||
# =============================================================================
|
||||
# Usage (from docker/):
|
||||
# PROXY_HTTP_PORT=8080 SURFSENSE_PUBLIC_URL=http://localhost:8080 \
|
||||
# docker compose -f docker-compose.yml -f docker-compose.proxy.yml up -d
|
||||
#
|
||||
# This overlay is for validation and custom deployments. The production
|
||||
# docker-compose.yml includes Caddy by default.
|
||||
# =============================================================================
|
||||
|
||||
services:
|
||||
backend:
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8929}:8000"
|
||||
|
||||
zero-cache:
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-5929}:4848"
|
||||
|
||||
frontend:
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3929}:3000"
|
||||
|
||||
proxy:
|
||||
image: caddy:2-alpine
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${PROXY_HTTP_PORT:-8080}:80"
|
||||
- "${PROXY_HTTPS_PORT:-8443}:443"
|
||||
volumes:
|
||||
- ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
- caddy_data:/data
|
||||
- caddy_config:/config
|
||||
environment:
|
||||
SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80}
|
||||
CERT_EMAIL: ${CERT_EMAIL:-}
|
||||
CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory}
|
||||
CERT_ACME_DNS: ${CERT_ACME_DNS:-}
|
||||
TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0}
|
||||
SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB}
|
||||
depends_on:
|
||||
frontend:
|
||||
condition: service_started
|
||||
backend:
|
||||
condition: service_healthy
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
caddy_data:
|
||||
name: surfsense-caddy-data
|
||||
caddy_config:
|
||||
name: surfsense-caddy-config
|
||||
|
|
@ -94,10 +94,39 @@ services:
|
|||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
# Single public entry point for the Docker stack. Comment this service out
|
||||
# only if you front SurfSense with your own reverse proxy.
|
||||
proxy:
|
||||
image: caddy:2-alpine
|
||||
# For DNS-01/wildcard certificates, replace image with:
|
||||
# build: ./proxy
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${LISTEN_HTTP_PORT:-3929}:80"
|
||||
- "${LISTEN_HTTPS_PORT:-443}:443"
|
||||
volumes:
|
||||
- ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
- caddy_data:/data
|
||||
- caddy_config:/config
|
||||
environment:
|
||||
SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80}
|
||||
CERT_EMAIL: ${CERT_EMAIL:-}
|
||||
CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory}
|
||||
CERT_ACME_DNS: ${CERT_ACME_DNS:-}
|
||||
TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0}
|
||||
SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB}
|
||||
depends_on:
|
||||
frontend:
|
||||
condition: service_started
|
||||
backend:
|
||||
condition: service_healthy
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
|
||||
backend:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}}
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8929}:8000"
|
||||
expose:
|
||||
- "8000"
|
||||
volumes:
|
||||
- shared_temp:/shared_tmp
|
||||
- object_store:/app/.local_object_store
|
||||
|
|
@ -115,7 +144,8 @@ services:
|
|||
UVICORN_LOOP: asyncio
|
||||
UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
|
||||
FILE_STORAGE_LOCAL_PATH: /app/.local_object_store
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}}
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}}
|
||||
BACKEND_URL: ${BACKEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}}
|
||||
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
WHATSAPP_BRIDGE_URL: ${WHATSAPP_BRIDGE_URL:-http://whatsapp-bridge:9929}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
|
|
@ -221,8 +251,8 @@ services:
|
|||
|
||||
zero-cache:
|
||||
image: rocicorp/zero:1.4.0
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-5929}:4848"
|
||||
expose:
|
||||
- "4848"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
environment:
|
||||
|
|
@ -256,16 +286,13 @@ services:
|
|||
|
||||
frontend:
|
||||
image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3929}:3000"
|
||||
expose:
|
||||
- "3000"
|
||||
environment:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
NEXT_PUBLIC_WHATSAPP_DISPLAY_PHONE_NUMBER: ${WHATSAPP_SHARED_DISPLAY_PHONE_NUMBER:-}
|
||||
FASTAPI_BACKEND_INTERNAL_URL: ${FASTAPI_BACKEND_INTERNAL_URL:-http://backend:8000}
|
||||
AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
depends_on:
|
||||
|
|
@ -286,5 +313,9 @@ volumes:
|
|||
name: surfsense-object-store
|
||||
zero_cache_data:
|
||||
name: surfsense-zero-cache
|
||||
caddy_data:
|
||||
name: surfsense-caddy-data
|
||||
caddy_config:
|
||||
name: surfsense-caddy-config
|
||||
whatsapp_sessions:
|
||||
name: surfsense-whatsapp-sessions
|
||||
|
|
|
|||
45
docker/proxy/Caddyfile
Normal file
45
docker/proxy/Caddyfile
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
{
|
||||
# Optional ACME/global settings. These are harmless in the default :80
|
||||
# localhost mode and become active when SURFSENSE_SITE_ADDRESS is a domain.
|
||||
{$CERT_EMAIL}
|
||||
acme_ca {$CERT_ACME_CA:https://acme-v02.api.letsencrypt.org/directory}
|
||||
{$CERT_ACME_DNS}
|
||||
servers {
|
||||
client_ip_headers X-Forwarded-For X-Real-IP
|
||||
trusted_proxies static {$TRUSTED_PROXIES:0.0.0.0/0}
|
||||
}
|
||||
}
|
||||
|
||||
(surfsense_proxy) {
|
||||
request_body {
|
||||
max_size {$SURFSENSE_MAX_BODY_SIZE:5GB}
|
||||
}
|
||||
|
||||
# Frontend-owned auth page (the post-login token handler). More specific than
|
||||
# /auth/*, so Caddy's matcher-specificity sort routes it here, not to backend.
|
||||
reverse_proxy /auth/callback* frontend:3000
|
||||
|
||||
# Backend auth routes (FastAPI Users + OAuth helpers).
|
||||
reverse_proxy /auth/* backend:8000
|
||||
|
||||
# Backend user profile routes (FastAPI Users users router, mounted at /users).
|
||||
reverse_proxy /users/* backend:8000
|
||||
|
||||
# Backend REST, streaming, connector OAuth, and messaging gateway endpoints.
|
||||
# FastAPI already serves /api/v1, so the path is forwarded unchanged.
|
||||
reverse_proxy /api/v1/* backend:8000 {
|
||||
flush_interval -1
|
||||
}
|
||||
|
||||
# Zero accepts a single path-component base URL (Zero >= 0.6).
|
||||
# Preserve /zero so browser cacheURL can be ${SURFSENSE_PUBLIC_URL}/zero.
|
||||
reverse_proxy /zero/* zero-cache:4848
|
||||
|
||||
# Next.js app and frontend-owned API routes:
|
||||
# /api/zero/*, /api/search, /api/contact, etc.
|
||||
reverse_proxy /* frontend:3000
|
||||
}
|
||||
|
||||
{$SURFSENSE_SITE_ADDRESS::80} {
|
||||
import surfsense_proxy
|
||||
}
|
||||
10
docker/proxy/Dockerfile
Normal file
10
docker/proxy/Dockerfile
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
FROM caddy:2-builder-alpine AS builder
|
||||
|
||||
RUN xcaddy build \
|
||||
--with github.com/caddy-dns/cloudflare \
|
||||
--with github.com/caddy-dns/digitalocean
|
||||
|
||||
FROM caddy:2-alpine
|
||||
|
||||
COPY --from=builder /usr/bin/caddy /usr/bin/caddy
|
||||
COPY Caddyfile /etc/caddy/Caddyfile
|
||||
|
|
@ -333,11 +333,13 @@ step "Downloading SurfSense files"
|
|||
info "Installation directory: ${INSTALL_DIR}"
|
||||
mkdir -p "${INSTALL_DIR}/scripts"
|
||||
mkdir -p "${INSTALL_DIR}/searxng"
|
||||
mkdir -p "${INSTALL_DIR}/proxy"
|
||||
|
||||
FILES=(
|
||||
"docker/docker-compose.yml:docker-compose.yml"
|
||||
"docker/docker-compose.gpu.yml:docker-compose.gpu.yml"
|
||||
"docker/.env.example:.env.example"
|
||||
"docker/proxy/Caddyfile:proxy/Caddyfile"
|
||||
"docker/postgresql.conf:postgresql.conf"
|
||||
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
|
||||
"docker/searxng/settings.yml:searxng/settings.yml"
|
||||
|
|
@ -532,9 +534,12 @@ _variant_display=$(grep '^SURFSENSE_VARIANT=' "${INSTALL_DIR}/.env" 2>/dev/null
|
|||
_variant_display="${_variant_display:-cpu}"
|
||||
step "SurfSense is now installed [${_version_display}]"
|
||||
|
||||
info " Frontend: http://localhost:3929"
|
||||
info " Backend: http://localhost:8929"
|
||||
info " API Docs: http://localhost:8929/docs"
|
||||
_public_url=$(grep '^SURFSENSE_PUBLIC_URL=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2- | tr -d '"' | head -1 || true)
|
||||
_public_url="${_public_url:-http://localhost:3929}"
|
||||
|
||||
info " SurfSense: ${_public_url}"
|
||||
info " Backend: ${_public_url}/api/v1"
|
||||
info " Zero sync: ${_public_url}/zero"
|
||||
info ""
|
||||
info " Config: ${INSTALL_DIR}/.env"
|
||||
info " Variant: ${_variant_display}"
|
||||
|
|
|
|||
|
|
@ -30,12 +30,9 @@ CELERY_TASK_DEFAULT_QUEUE=surfsense
|
|||
# Optional: TTL in seconds for connector indexing lock key
|
||||
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
|
||||
|
||||
# Messaging Gateway (global)
|
||||
# GATEWAY_ENABLED: master switch for ALL messaging gateway channels (Telegram, WhatsApp,
|
||||
# Slack, Discord). When FALSE, no gateway background workers/supervisors start and all
|
||||
# gateway HTTP routes (webhooks, OAuth callbacks, pairing) return 404. Set per-channel
|
||||
# flags below to control individual platforms once the gateway is enabled.
|
||||
GATEWAY_ENABLED=TRUE
|
||||
# Messaging Gateway: disabled by default; set TRUE to enable chat integrations.
|
||||
# Supported messaging gateways: WhatsApp, Telegram, Discord, Slack
|
||||
# GATEWAY_ENABLED=TRUE
|
||||
|
||||
# Telegram Gateway
|
||||
# TELEGRAM_WEBHOOK_SECRET must be 1-256 chars and contain only A-Z, a-z, 0-9, _ or -
|
||||
|
|
@ -326,6 +323,42 @@ FILE_STORAGE_BACKEND=local
|
|||
# AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net
|
||||
# AZURE_STORAGE_CONTAINER=surfsense-documents
|
||||
|
||||
# ETL Parse Cache
|
||||
# Reuse parser output for identical file bytes across workspaces (skips paid
|
||||
# re-parsing on LlamaCloud / Azure DI / Unstructured). Off by default.
|
||||
ETL_CACHE_ENABLED=false
|
||||
# Bump to invalidate all cached entries after a parser/behaviour change.
|
||||
# ETL_CACHE_PARSER_VERSION=1
|
||||
# Prune entries unused for this many days.
|
||||
# ETL_CACHE_TTL_DAYS=90
|
||||
# Soft cap on total cached markdown; coldest entries are evicted past it.
|
||||
# ETL_CACHE_MAX_TOTAL_MB=5120
|
||||
# Rows deleted per eviction pass.
|
||||
# ETL_CACHE_EVICTION_BATCH=500
|
||||
# Optional dedicated blob storage; unset reuses the main file storage backend.
|
||||
# ETL_CACHE_STORAGE_BACKEND=azure
|
||||
# ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache
|
||||
# ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache
|
||||
|
||||
# Embedding Cache
|
||||
# Reuse chunk+embedding output for identical markdown across workspaces (skips
|
||||
# re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend.
|
||||
# Off by default.
|
||||
EMBEDDING_CACHE_ENABLED=false
|
||||
# Bump to invalidate all cached embedding sets after a chunker change.
|
||||
# EMBEDDING_CACHE_CHUNKER_VERSION=1
|
||||
# Prune entries unused for this many days.
|
||||
# EMBEDDING_CACHE_TTL_DAYS=90
|
||||
# Soft cap on total cached embeddings; coldest entries are evicted past it.
|
||||
# EMBEDDING_CACHE_MAX_TOTAL_MB=5120
|
||||
# Rows deleted per eviction pass.
|
||||
# EMBEDDING_CACHE_EVICTION_BATCH=500
|
||||
|
||||
# Incremental re-indexing: on document edits, keep chunks whose text is
|
||||
# unchanged (reusing their embeddings) and embed only new/changed ones.
|
||||
# Set to false to fall back to delete-all + full re-embed (kill switch).
|
||||
# CHUNK_RECONCILE_ENABLED=true
|
||||
|
||||
# Daytona Sandbox (isolated code execution)
|
||||
# DAYTONA_SANDBOX_ENABLED=FALSE
|
||||
# DAYTONA_API_KEY=your-daytona-api-key
|
||||
|
|
@ -365,7 +398,9 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
|
||||
# Observability - OTel
|
||||
# SURFSENSE_ENABLE_OTEL=false
|
||||
# Disabled by default. Uncomment to enable OpenTelemetry.
|
||||
# SURFSENSE_ENABLE_OTEL=true
|
||||
|
||||
# OpenTelemetry - endpoint enables export; absent = no-op.
|
||||
# Production should point at an OTel Collector. For local docker-compose.dev.yml,
|
||||
# use http://otel-lgtm:4317 instead.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Revision ID: 138
|
|||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||
Add a single thread-level column to persist the Auto model pin:
|
||||
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,19 @@ down_revision: str | None = "157"
|
|||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Postgres publication the podcasts table is detached from before its status
# column is retyped.
PUBLICATION_NAME = "zero_publication"
# Labels of the podcast_status enum after upgrade. Order matters: the labels
# read back from pg_enum are sorted by enumsortorder and compared to this
# sequence with list equality.
TARGET_STATUS_LABELS = (
    "pending",
    "awaiting_brief",
    "drafting",
    "awaiting_review",
    "rendering",
    "ready",
    "failed",
    "cancelled",
)
# Labels of the podcast_status enum that downgrade restores (same ordering rule).
LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed")
|
||||
|
||||
|
||||
def _drop_podcasts_from_publication() -> None:
|
||||
"""Detach podcasts from zero_publication so status can be retyped.
|
||||
|
|
@ -28,31 +41,103 @@ def _drop_podcasts_from_publication() -> None:
|
|||
published = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = 'zero_publication' "
|
||||
"WHERE pubname = :publication "
|
||||
"AND schemaname = current_schema() AND tablename = 'podcasts'"
|
||||
)
|
||||
),
|
||||
{"publication": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if published:
|
||||
op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";')
|
||||
op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";')
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
def _enum_labels(type_name: str) -> list[str] | None:
    """Return the labels of an enum type in the current schema, in sort order.

    Args:
        type_name: Name of the Postgres enum type to look up.

    Returns:
        The labels ordered by ``enumsortorder``, or ``None`` when the lookup
        yields no rows (no enum of that name in ``current_schema()``).
    """
    lookup = sa.text(
        "SELECT e.enumlabel "
        "FROM pg_type t "
        "JOIN pg_namespace n ON n.oid = t.typnamespace "
        "JOIN pg_enum e ON e.enumtypid = t.oid "
        "WHERE n.nspname = current_schema() AND t.typname = :type_name "
        "ORDER BY e.enumsortorder"
    )
    records = op.get_bind().execute(lookup, {"type_name": type_name}).fetchall()
    labels = [str(record[0]) for record in records]
    # An empty result is reported as None so callers can tell "type absent"
    # apart from a list of labels.
    return labels or None
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;")
|
||||
|
||||
def _column_type_name(table: str, column: str) -> str | None:
    """Return the ``udt_name`` of a column in the current schema.

    Args:
        table: Table that owns the column.
        column: Column whose underlying type name is wanted.

    Returns:
        The type name from ``information_schema.columns``, or ``None`` when
        no matching column row is found.
    """
    lookup = sa.text(
        "SELECT udt_name "
        "FROM information_schema.columns "
        "WHERE table_schema = current_schema() "
        "AND table_name = :table AND column_name = :column"
    )
    params = {"table": table, "column": column}
    match = op.get_bind().execute(lookup, params).fetchone()
    if not match:
        return None
    return str(match[0])
|
||||
|
||||
|
||||
def _ensure_status_enum(
|
||||
*,
|
||||
desired_labels: tuple[str, ...],
|
||||
temporary_type: str,
|
||||
create_sql: str,
|
||||
alter_sql: str,
|
||||
default_value: str,
|
||||
) -> None:
|
||||
current_labels = _enum_labels("podcast_status")
|
||||
desired = list(desired_labels)
|
||||
|
||||
if current_labels != desired:
|
||||
if current_labels is None:
|
||||
if _enum_labels(temporary_type) is None:
|
||||
raise RuntimeError("podcast_status enum is missing")
|
||||
elif _enum_labels(temporary_type) is None:
|
||||
op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"podcast_status and its temporary replacement both exist"
|
||||
)
|
||||
|
||||
if _enum_labels("podcast_status") is None:
|
||||
op.execute(create_sql)
|
||||
|
||||
if _enum_labels("podcast_status") != desired:
|
||||
raise RuntimeError("podcast_status enum is not in the expected shape")
|
||||
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
if _column_type_name("podcasts", "status") != "podcast_status":
|
||||
op.execute(alter_sql)
|
||||
op.execute(
|
||||
"""
|
||||
f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';"
|
||||
)
|
||||
|
||||
if _enum_labels(temporary_type) is not None:
|
||||
op.execute(f"DROP TYPE {temporary_type};")
|
||||
|
||||
|
||||
def _upgrade_status_enum() -> None:
|
||||
_ensure_status_enum(
|
||||
desired_labels=TARGET_STATUS_LABELS,
|
||||
temporary_type="podcast_status_old",
|
||||
create_sql="""
|
||||
CREATE TYPE podcast_status AS ENUM (
|
||||
'pending', 'awaiting_brief', 'drafting', 'awaiting_review',
|
||||
'rendering', 'ready', 'failed', 'cancelled'
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
""",
|
||||
alter_sql="""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
|
|
@ -61,10 +146,43 @@ def upgrade() -> None:
|
|||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
""",
|
||||
default_value="pending",
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';")
|
||||
op.execute("DROP TYPE podcast_status_old;")
|
||||
|
||||
|
||||
def _downgrade_status_enum() -> None:
    """Retype ``podcasts.status`` back onto the legacy four-label enum.

    Delegates to ``_ensure_status_enum`` with ``podcast_status_new`` as the
    temporary name for the enum being replaced. Labels that only exist in the
    expanded lifecycle are folded onto legacy ones by the CASE expression
    below; every other value is cast through unchanged. The column default
    becomes ``'ready'``.
    """
    recreate_legacy_type = (
        "CREATE TYPE podcast_status AS ENUM "
        "('pending', 'generating', 'ready', 'failed');"
    )
    collapse_statuses = """
        ALTER TABLE podcasts
            ALTER COLUMN status TYPE podcast_status
            USING (
                CASE status::text
                    WHEN 'awaiting_brief' THEN 'pending'
                    WHEN 'drafting' THEN 'generating'
                    WHEN 'awaiting_review' THEN 'generating'
                    WHEN 'rendering' THEN 'generating'
                    WHEN 'cancelled' THEN 'failed'
                    ELSE status::text
                END
            )::podcast_status;
        """
    _ensure_status_enum(
        desired_labels=LEGACY_STATUS_LABELS,
        temporary_type="podcast_status_new",
        create_sql=recreate_legacy_type,
        alter_sql=collapse_statuses,
        default_value="ready",
    )
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
_upgrade_status_enum()
|
||||
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;")
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;")
|
||||
|
|
@ -83,6 +201,8 @@ def upgrade() -> None:
|
|||
|
||||
|
||||
def downgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;")
|
||||
|
|
@ -92,27 +212,4 @@ def downgrade() -> None:
|
|||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;")
|
||||
|
||||
# Collapse the expanded lifecycle back onto the original four values.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;")
|
||||
op.execute(
|
||||
"CREATE TYPE podcast_status AS ENUM "
|
||||
"('pending', 'generating', 'ready', 'failed');"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'awaiting_brief' THEN 'pending'
|
||||
WHEN 'drafting' THEN 'generating'
|
||||
WHEN 'awaiting_review' THEN 'generating'
|
||||
WHEN 'rendering' THEN 'generating'
|
||||
WHEN 'cancelled' THEN 'failed'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';")
|
||||
op.execute("DROP TYPE podcast_status_new;")
|
||||
_downgrade_status_enum()
|
||||
|
|
|
|||
299
surfsense_backend/alembic/versions/160_add_model_connections.py
Normal file
299
surfsense_backend/alembic/versions/160_add_model_connections.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
"""add model connections
|
||||
|
||||
Revision ID: 160
|
||||
Revises: 159
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "160"
|
||||
down_revision: str | None = "159"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
# Enum backing connections.scope. create_type=False stops SQLAlchemy from
# emitting CREATE TYPE implicitly alongside table DDL; upgrade() creates the
# type explicitly with checkfirst=True and downgrade() drops it the same way.
connection_scope = postgresql.ENUM(
    "GLOBAL",
    "SEARCH_SPACE",
    "USER",
    name="connectionscope",
    create_type=False,
)
# Enum backing models.source ("DISCOVERED" is that column's server default).
# Created and dropped explicitly, like connection_scope.
model_source = postgresql.ENUM(
    "DISCOVERED",
    "MANUAL",
    name="modelsource",
    create_type=False,
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
    """Report whether the inspector on the migration bind lists *table_name*."""
    inspector = sa.inspect(op.get_bind())
    known_tables = inspector.get_table_names()
    return table_name in known_tables
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
    """Report whether *table_name* exists and has a column *column_name*.

    A missing table yields ``False`` without reflecting any columns.
    """
    if _table_exists(table_name):
        reflected = sa.inspect(op.get_bind()).get_columns(table_name)
        return any(entry["name"] == column_name for entry in reflected)
    return False
|
||||
|
||||
|
||||
def _index_exists(table_name: str, index_name: str) -> bool:
    """Report whether *table_name* exists and carries an index *index_name*.

    A missing table yields ``False`` without reflecting any indexes.
    """
    if _table_exists(table_name):
        reflected = sa.inspect(op.get_bind()).get_indexes(table_name)
        return any(entry["name"] == index_name for entry in reflected)
    return False
|
||||
|
||||
|
||||
def _create_index_if_missing(
    index_name: str,
    table_name: str,
    columns: list[str],
) -> None:
    """Create a non-unique index unless one with that name is already present.

    Args:
        index_name: Name of the index to create.
        table_name: Table the index belongs to.
        columns: Column names covered by the index.
    """
    if _index_exists(table_name, index_name):
        return
    op.create_index(index_name, table_name, columns, unique=False)
|
||||
|
||||
|
||||
def _add_searchspace_column_if_missing(
    column_name: str,
    *,
    server_default: object | None = None,
) -> None:
    """Add a nullable integer column to ``searchspaces`` if it is not there yet.

    Args:
        column_name: Name of the column to add.
        server_default: Optional server-side default passed to the new column.
    """
    if _column_exists("searchspaces", column_name):
        return
    new_column = sa.Column(
        column_name,
        sa.Integer(),
        nullable=True,
        server_default=server_default,
    )
    op.add_column("searchspaces", new_column)
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
    """Drop *column_name* from *table_name*; do nothing when it is absent."""
    if not _column_exists(table_name, column_name):
        return
    op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
    """Drop *index_name* on *table_name*; do nothing when it is absent."""
    if not _index_exists(table_name, index_name):
        return
    op.drop_index(index_name, table_name=table_name)
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create or reconcile the ``connections`` and ``models`` tables.

    Every step is guarded by an existence check, so the migration can run
    against a database where some of these objects are already present:

    * ``connections``: created when missing; otherwise a legacy
      ``litellm_provider`` / ``native_provider`` column is renamed to
      ``provider`` (or ``provider`` is added), and the ``protocol`` column
      and its index are dropped.
    * ``models``: created when missing; otherwise the flat capability columns
      are added and the older ``capabilities*`` columns are dropped.
    * ``searchspaces``: gains ``chat_model_id``, ``image_gen_model_id`` and
      ``vision_model_id`` with server default ``0``; existing NULLs are
      rewritten to ``0``.
    """
    connection = op.get_bind()
    for enum_type in (connection_scope, model_source):
        enum_type.create(connection, checkfirst=True)

    if not _table_exists("connections"):
        op.create_table(
            "connections",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
            sa.Column("provider", sa.String(length=100), nullable=False),
            sa.Column("base_url", sa.String(length=500), nullable=True),
            sa.Column("api_key", sa.String(), nullable=True),
            sa.Column(
                "extra",
                postgresql.JSONB(astext_type=sa.Text()),
                server_default=sa.text("'{}'::jsonb"),
                nullable=False,
            ),
            sa.Column("scope", connection_scope, nullable=False),
            sa.Column(
                "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
            ),
            sa.Column("search_space_id", sa.Integer(), nullable=True),
            sa.Column("user_id", sa.UUID(), nullable=True),
            sa.CheckConstraint(
                "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
                "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
                "(scope = 'USER' AND user_id IS NOT NULL)",
                name="ck_connections_scope_owner",
            ),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
            sa.PrimaryKeyConstraint("id"),
        )
    else:
        if not _column_exists("connections", "provider"):
            # Prefer renaming a legacy provider column; fall back to adding
            # a fresh one only when neither legacy name is present.
            for legacy_column in ("litellm_provider", "native_provider"):
                if _column_exists("connections", legacy_column):
                    op.alter_column(
                        "connections",
                        legacy_column,
                        new_column_name="provider",
                        existing_type=sa.String(length=100),
                        existing_nullable=True,
                    )
                    op.alter_column(
                        "connections",
                        "provider",
                        existing_type=sa.String(length=100),
                        nullable=False,
                    )
                    break
            else:
                op.add_column(
                    "connections",
                    sa.Column("provider", sa.String(length=100), nullable=False),
                )
        _drop_index_if_exists("connections", "ix_connections_protocol")
        _drop_column_if_exists("connections", "protocol")

    # Carry a legacy provider index over to the new name; the target-name
    # check is repeated per candidate so at most one rename happens.
    for legacy_index in (
        "ix_connections_native_provider",
        "ix_connections_litellm_provider",
    ):
        if _index_exists("connections", legacy_index) and not _index_exists(
            "connections", "ix_connections_provider"
        ):
            op.execute(
                f"ALTER INDEX {legacy_index} RENAME TO ix_connections_provider"
            )
    _create_index_if_missing("ix_connections_provider", "connections", ["provider"])
    _create_index_if_missing("ix_connections_scope", "connections", ["scope"])

    if not _table_exists("models"):
        op.create_table(
            "models",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
            sa.Column("connection_id", sa.Integer(), nullable=False),
            sa.Column("model_id", sa.String(length=255), nullable=False),
            sa.Column("display_name", sa.String(length=255), nullable=True),
            sa.Column(
                "source",
                model_source,
                server_default="DISCOVERED",
                nullable=False,
            ),
            sa.Column("supports_chat", sa.Boolean(), nullable=True),
            sa.Column("max_input_tokens", sa.Integer(), nullable=True),
            sa.Column("supports_image_input", sa.Boolean(), nullable=True),
            sa.Column("supports_tools", sa.Boolean(), nullable=True),
            sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
            sa.Column(
                "capabilities_override",
                postgresql.JSONB(astext_type=sa.Text()),
                server_default=sa.text("'{}'::jsonb"),
                nullable=False,
            ),
            sa.Column(
                "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
            ),
            sa.Column("billing_tier", sa.String(length=50), nullable=True),
            sa.Column(
                "catalog",
                postgresql.JSONB(astext_type=sa.Text()),
                server_default=sa.text("'{}'::jsonb"),
                nullable=False,
            ),
            sa.ForeignKeyConstraint(
                ["connection_id"], ["connections.id"], ondelete="CASCADE"
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "connection_id", "model_id", name="uq_models_connection_model_id"
            ),
        )
    else:
        capability_columns = (
            ("supports_chat", sa.Boolean),
            ("max_input_tokens", sa.Integer),
            ("supports_image_input", sa.Boolean),
            ("supports_tools", sa.Boolean),
            ("supports_image_generation", sa.Boolean),
        )
        for capability_name, type_factory in capability_columns:
            if not _column_exists("models", capability_name):
                op.add_column(
                    "models",
                    sa.Column(capability_name, type_factory(), nullable=True),
                )
        for stale_column in (
            "capabilities",
            "capabilities_declared",
            "capabilities_verified",
        ):
            _drop_column_if_exists("models", stale_column)
    _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
    _create_index_if_missing("ix_models_model_id", "models", ["model_id"])
    _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])

    model_slots = ("chat_model_id", "image_gen_model_id", "vision_model_id")
    for slot in model_slots:
        _add_searchspace_column_if_missing(slot, server_default=sa.text("0"))
    # Columns that already existed may lack the default, so set it on all three.
    for slot in model_slots:
        op.alter_column(
            "searchspaces",
            slot,
            existing_type=sa.Integer(),
            existing_nullable=True,
            server_default=sa.text("0"),
        )
    op.execute(
        """
        UPDATE searchspaces
        SET
            chat_model_id = COALESCE(chat_model_id, 0),
            image_gen_model_id = COALESCE(image_gen_model_id, 0),
            vision_model_id = COALESCE(vision_model_id, 0)
        """
    )

    op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the objects ``upgrade`` manages, tolerating partial state.

    ``upgrade`` guards every step with an existence check; this mirrors that
    so a downgrade does not abort halfway when a column, index or table is
    already gone (for example after an interrupted earlier downgrade).

    Removal order: the three ``searchspaces`` model columns, then ``models``
    and its indexes, then ``connections`` and its indexes, then the two enum
    types.

    NOTE: ``connections`` is dropped even if it existed before ``upgrade``
    reconciled it; legacy columns renamed or removed by ``upgrade`` are not
    restored.
    """
    for column_name in ("vision_model_id", "image_gen_model_id", "chat_model_id"):
        _drop_column_if_exists("searchspaces", column_name)

    for index_name in (
        "ix_models_billing_tier",
        "ix_models_model_id",
        "ix_models_connection_id",
    ):
        _drop_index_if_exists("models", index_name)
    if _table_exists("models"):
        op.drop_table("models")

    for index_name in ("ix_connections_scope", "ix_connections_provider"):
        _drop_index_if_exists("connections", index_name)
    if _table_exists("connections"):
        op.drop_table("connections")

    # Enum types go last: the tables that referenced them are gone by now.
    bind = op.get_bind()
    model_source.drop(bind, checkfirst=True)
    connection_scope.drop(bind, checkfirst=True)
|
||||
|
|
@ -0,0 +1,270 @@
|
|||
"""remove legacy model config tables
|
||||
|
||||
Revision ID: 161
|
||||
Revises: 160
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "161"
|
||||
down_revision: str | None = "160"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
# Legacy enum types. upgrade() removes them with DROP TYPE IF EXISTS; they are
# declared here so downgrade() can recreate them before rebuilding the legacy
# config tables. create_type=False stops SQLAlchemy from emitting CREATE TYPE
# implicitly with table DDL; downgrade() calls .create(bind, checkfirst=True).

# Type of new_llm_configs.provider.
litellm_provider = postgresql.ENUM(
    "OPENAI",
    "ANTHROPIC",
    "GOOGLE",
    "AZURE_OPENAI",
    "BEDROCK",
    "VERTEX_AI",
    "GROQ",
    "COHERE",
    "MISTRAL",
    "DEEPSEEK",
    "XAI",
    "OPENROUTER",
    "TOGETHER_AI",
    "FIREWORKS_AI",
    "REPLICATE",
    "PERPLEXITY",
    "OLLAMA",
    "ALIBABA_QWEN",
    "MOONSHOT",
    "ZHIPU",
    "ANYSCALE",
    "DEEPINFRA",
    "CEREBRAS",
    "SAMBANOVA",
    "AI21",
    "CLOUDFLARE",
    "DATABRICKS",
    "COMETAPI",
    "HUGGINGFACE",
    "GITHUB_MODELS",
    "MINIMAX",
    "CUSTOM",
    name="litellmprovider",
    create_type=False,
)
# Type of image_generation_configs.provider.
image_gen_provider = postgresql.ENUM(
    "OPENAI",
    "AZURE_OPENAI",
    "GOOGLE",
    "VERTEX_AI",
    "BEDROCK",
    "RECRAFT",
    "OPENROUTER",
    "XINFERENCE",
    "NSCALE",
    name="imagegenprovider",
    create_type=False,
)
# Type of vision_llm_configs.provider.
vision_provider = postgresql.ENUM(
    "OPENAI",
    "ANTHROPIC",
    "GOOGLE",
    "AZURE_OPENAI",
    "VERTEX_AI",
    "BEDROCK",
    "XAI",
    "OPENROUTER",
    "OLLAMA",
    "GROQ",
    "TOGETHER_AI",
    "FIREWORKS_AI",
    "DEEPSEEK",
    "MISTRAL",
    "CUSTOM",
    name="visionprovider",
    create_type=False,
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
    """Report whether the inspector on the migration bind lists *table_name*."""
    inspector = sa.inspect(op.get_bind())
    known_tables = inspector.get_table_names()
    return table_name in known_tables
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
    """Report whether *table_name* exists and has a column *column_name*.

    A missing table yields ``False`` without reflecting any columns.
    """
    if _table_exists(table_name):
        reflected = sa.inspect(op.get_bind()).get_columns(table_name)
        return any(entry["name"] == column_name for entry in reflected)
    return False
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
    """Drop *column_name* from *table_name*; do nothing when it is absent."""
    if not _column_exists(table_name, column_name):
        return
    op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _rename_column_if_exists(
    table_name: str,
    old_column_name: str,
    new_column_name: str,
    *,
    existing_type: TypeEngine,
    existing_nullable: bool = True,
) -> None:
    """Rename a column when the old name exists and the new name is free.

    Args:
        table_name: Table that owns the column.
        old_column_name: Current column name.
        new_column_name: Name the column should have afterwards.
        existing_type: Column type, forwarded to ``op.alter_column``.
        existing_nullable: Column nullability, forwarded to ``op.alter_column``.
    """
    source_present = _column_exists(table_name, old_column_name)
    # Nothing to rename, or the target name is already taken: leave as is.
    if not source_present or _column_exists(table_name, new_column_name):
        return
    op.alter_column(
        table_name,
        old_column_name,
        new_column_name=new_column_name,
        existing_type=existing_type,
        existing_nullable=existing_nullable,
    )
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Drop the legacy model-config tables, columns and enum types.

    Each removal is conditional on the object being present. The
    ``image_generations.image_generation_config_id`` column is renamed to
    ``image_gen_model_id`` rather than dropped.
    """
    legacy_tables = (
        "new_llm_configs",
        "vision_llm_configs",
        "image_generation_configs",
    )
    for legacy_table in legacy_tables:
        if not _table_exists(legacy_table):
            continue
        op.drop_table(legacy_table)

    for legacy_column in (
        "agent_llm_id",
        "image_generation_config_id",
        "vision_llm_config_id",
    ):
        _drop_column_if_exists("searchspaces", legacy_column)

    _rename_column_if_exists(
        "image_generations",
        "image_generation_config_id",
        "image_gen_model_id",
        existing_type=sa.Integer(),
    )

    # The enum types can only go once the tables that used them are gone.
    for enum_name in ("litellmprovider", "imagegenprovider", "visionprovider"):
        op.execute(f"DROP TYPE IF EXISTS {enum_name}")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Recreate the legacy model-config schema removed by ``upgrade``.

    Restores, each only when missing: the three provider enum types, the
    original name of ``image_generations.image_generation_config_id``, the
    three legacy ``searchspaces`` columns, and the ``image_generation_configs``,
    ``vision_llm_configs`` and ``new_llm_configs`` tables with their name
    indexes. Only schema is rebuilt; rows dropped by ``upgrade`` are not
    restored.
    """
    bind = op.get_bind()
    # The enums are declared with create_type=False, so create them explicitly
    # before any table that uses them.
    litellm_provider.create(bind, checkfirst=True)
    image_gen_provider.create(bind, checkfirst=True)
    vision_provider.create(bind, checkfirst=True)

    _rename_column_if_exists(
        "image_generations",
        "image_gen_model_id",
        "image_generation_config_id",
        existing_type=sa.Integer(),
    )

    # _column_exists() is False for a missing table, so the outer check keeps
    # add_column from being attempted against a table that is not there.
    if _table_exists("searchspaces"):
        if not _column_exists("searchspaces", "agent_llm_id"):
            op.add_column(
                "searchspaces",
                sa.Column("agent_llm_id", sa.Integer(), nullable=True),
            )
        if not _column_exists("searchspaces", "image_generation_config_id"):
            op.add_column(
                "searchspaces",
                sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
            )
        if not _column_exists("searchspaces", "vision_llm_config_id"):
            op.add_column(
                "searchspaces",
                sa.Column("vision_llm_config_id", sa.Integer(), nullable=True),
            )

    if not _table_exists("image_generation_configs"):
        op.create_table(
            "image_generation_configs",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
            sa.Column("name", sa.String(length=100), nullable=False),
            sa.Column("description", sa.String(length=500), nullable=True),
            sa.Column("provider", image_gen_provider, nullable=False),
            sa.Column("custom_provider", sa.String(length=100), nullable=True),
            sa.Column("model_name", sa.String(length=100), nullable=False),
            sa.Column("api_key", sa.String(), nullable=False),
            sa.Column("api_base", sa.String(length=500), nullable=True),
            sa.Column("api_version", sa.String(length=50), nullable=True),
            sa.Column("litellm_params", sa.JSON(), nullable=True),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
            sa.PrimaryKeyConstraint("id"),
        )
        op.create_index(
            op.f("ix_image_generation_configs_name"),
            "image_generation_configs",
            ["name"],
            unique=False,
        )

    if not _table_exists("vision_llm_configs"):
        op.create_table(
            "vision_llm_configs",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
            sa.Column("name", sa.String(length=100), nullable=False),
            sa.Column("description", sa.String(length=500), nullable=True),
            sa.Column("provider", vision_provider, nullable=False),
            sa.Column("custom_provider", sa.String(length=100), nullable=True),
            sa.Column("model_name", sa.String(length=100), nullable=False),
            sa.Column("api_key", sa.String(), nullable=False),
            sa.Column("api_base", sa.String(length=500), nullable=True),
            sa.Column("api_version", sa.String(length=50), nullable=True),
            sa.Column("litellm_params", sa.JSON(), nullable=True),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
            sa.PrimaryKeyConstraint("id"),
        )
        op.create_index(
            op.f("ix_vision_llm_configs_name"),
            "vision_llm_configs",
            ["name"],
            unique=False,
        )

    if not _table_exists("new_llm_configs"):
        op.create_table(
            "new_llm_configs",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
            sa.Column("name", sa.String(length=100), nullable=False),
            sa.Column("description", sa.String(length=500), nullable=True),
            sa.Column("provider", litellm_provider, nullable=False),
            sa.Column("custom_provider", sa.String(length=100), nullable=True),
            sa.Column("model_name", sa.String(length=100), nullable=False),
            sa.Column("api_key", sa.String(), nullable=False),
            sa.Column("api_base", sa.String(length=500), nullable=True),
            sa.Column("litellm_params", sa.JSON(), nullable=True),
            sa.Column("system_instructions", sa.Text(), nullable=False),
            sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False),
            sa.Column("citations_enabled", sa.Boolean(), nullable=False),
            sa.Column("search_space_id", sa.Integer(), nullable=False),
            sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
            sa.ForeignKeyConstraint(
                ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
            sa.PrimaryKeyConstraint("id"),
        )
        op.create_index(
            op.f("ix_new_llm_configs_name"),
            "new_llm_configs",
            ["name"],
            unique=False,
        )
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""add etl_cache_parses table for content-addressed parse reuse
|
||||
|
||||
Revision ID: 162
|
||||
Revises: 161
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "162"
|
||||
down_revision: str | None = "161"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``etl_cache_parses`` table and its two timestamp indexes.

    All statements use ``IF NOT EXISTS``, so re-running against a database
    that already has them is a no-op. A cache entry is unique per
    ``(source_sha256, etl_service, mode, parser_version)``.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS etl_cache_parses (
            id SERIAL PRIMARY KEY,
            source_sha256 VARCHAR(64) NOT NULL,
            etl_service VARCHAR(32) NOT NULL,
            mode VARCHAR(16) NOT NULL,
            parser_version INTEGER NOT NULL,
            storage_backend VARCHAR(32) NOT NULL,
            storage_key TEXT NOT NULL,
            size_bytes BIGINT NOT NULL,
            content_type VARCHAR(32) NOT NULL,
            actual_pages INTEGER NOT NULL DEFAULT 0,
            times_reused BIGINT NOT NULL DEFAULT 0,
            last_used_at TIMESTAMP WITH TIME ZONE NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            CONSTRAINT uq_etl_cache_parses_key
                UNIQUE (source_sha256, etl_service, mode, parser_version)
        );
        """
    op.execute(create_table_sql)

    for indexed_column in ("last_used_at", "created_at"):
        op.execute(
            f"CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_{indexed_column} "
            f"ON etl_cache_parses({indexed_column});"
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``etl_cache_parses`` indexes and then the table, if present."""
    teardown_statements = (
        "DROP INDEX IF EXISTS ix_etl_cache_parses_created_at;",
        "DROP INDEX IF EXISTS ix_etl_cache_parses_last_used_at;",
        "DROP TABLE IF EXISTS etl_cache_parses;",
    )
    for statement in teardown_statements:
        op.execute(statement)
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""add embedding_cache_sets table for content-addressed embedding reuse
|
||||
|
||||
Revision ID: 163
|
||||
Revises: 162
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "163"
|
||||
down_revision: str | None = "162"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``embedding_cache_sets`` table and its two timestamp indexes.

    All statements use ``IF NOT EXISTS``, so re-running against a database
    that already has them is a no-op. A cache entry is unique per
    ``(markdown_sha256, embedding_model, chunker_kind, chunker_version)``.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS embedding_cache_sets (
            id SERIAL PRIMARY KEY,
            markdown_sha256 VARCHAR(64) NOT NULL,
            embedding_model VARCHAR(255) NOT NULL,
            embedding_dim INTEGER NOT NULL,
            chunker_kind VARCHAR(8) NOT NULL,
            chunker_version INTEGER NOT NULL,
            storage_backend VARCHAR(32) NOT NULL,
            storage_key TEXT NOT NULL,
            size_bytes BIGINT NOT NULL,
            chunk_count INTEGER NOT NULL DEFAULT 0,
            times_reused BIGINT NOT NULL DEFAULT 0,
            last_used_at TIMESTAMP WITH TIME ZONE NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            CONSTRAINT uq_embedding_cache_sets_key
                UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version)
        );
        """
    op.execute(create_table_sql)

    for indexed_column in ("last_used_at", "created_at"):
        op.execute(
            f"CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_{indexed_column} "
            f"ON embedding_cache_sets({indexed_column});"
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``embedding_cache_sets`` indexes, then the table itself.

    Each statement carries ``IF EXISTS``, so a partially applied upgrade can
    still be rolled back.
    """
    teardown_statements = (
        "DROP INDEX IF EXISTS ix_embedding_cache_sets_created_at;",
        "DROP INDEX IF EXISTS ix_embedding_cache_sets_last_used_at;",
        "DROP TABLE IF EXISTS embedding_cache_sets;",
    )
    for statement in teardown_statements:
        op.execute(statement)
|
||||
219
surfsense_backend/alembic/versions/164_remove_inactive_users.py
Normal file
219
surfsense_backend/alembic/versions/164_remove_inactive_users.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""remove users that never logged back in (last_login IS NULL)
|
||||
|
||||
Migration 103 added ``user.last_login``. Any user whose ``last_login`` is still
|
||||
NULL has never authenticated since that column existed, i.e. they never logged
|
||||
back in. This migration purges those users together with everything that hangs
|
||||
off them: the search spaces they own, and (via ON DELETE CASCADE)
|
||||
``searchspaces -> documents -> chunks`` plus all other user/space-scoped rows.
|
||||
|
||||
This runs BEFORE the chunks.position backfill (revision 165) on purpose: it
|
||||
removes a large amount of dead chunk data first, so the expensive backfill has
|
||||
far fewer rows to rewrite.
|
||||
|
||||
Work is done in committed batches (not one giant cascading DELETE) so that on a
|
||||
large table it streams progress to the alembic console, keeps each transaction
|
||||
small, bounds WAL/bloat growth, and is resumable if interrupted.
|
||||
|
||||
Revision ID: 164
|
||||
Revises: 163
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "164"
|
||||
down_revision: str | None = "163"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Documents removed per committed batch. Each document delete cascades to its
|
||||
# chunks (via ix_chunks_document_id), so keep this modest to bound batch size.
|
||||
DOC_BATCH = 1_000
|
||||
# Users removed per committed batch. Each cascades to owned search spaces and
|
||||
# the remaining space-/user-scoped rows.
|
||||
USER_BATCH = 500
|
||||
# Minimum seconds between progress log lines (keeps the console readable).
|
||||
LOG_EVERY_SECONDS = 5.0
|
||||
|
||||
USER_SCRATCH = "_inactive_user_ids"
|
||||
DOC_SCRATCH = "_inactive_doc_ids"
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def _fmt_duration(seconds: float) -> str:
    """Format a duration in seconds as a compact ``1h02m05s`` style string.

    Fractional seconds are truncated. Hours and minutes are omitted while they
    are zero, so short durations render as ``"42s"`` or ``"3m07s"``.

    Negative inputs are clamped to zero: ``divmod`` floors toward negative
    infinity, so ``-5`` would otherwise render as ``"-1h59m55s"``.
    """
    whole = max(int(seconds), 0)
    h, rem = divmod(whole, 3600)
    m, s = divmod(rem, 60)
    if h:
        return f"{h}h{m:02d}m{s:02d}s"
    if m:
        return f"{m}m{s:02d}s"
    return f"{s}s"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Delete every user whose ``last_login`` is NULL, plus the data they own.

    The work happens inside an autocommit block so each batch commits on its
    own. Target ids are materialized into UNLOGGED scratch tables that are
    rebuilt from scratch on every run, so re-running after an interruption
    simply picks up whoever still has a NULL ``last_login``.

    Phase 1 removes the documents living in search spaces owned by those
    users; phase 2 removes the users themselves.
    """
    bind = op.get_bind()

    def _row_count(table: str) -> int:
        # Scratch tables are small id lists, so an exact count is cheap.
        return bind.execute(sa.text(f"SELECT count(*) FROM {table}")).scalar() or 0

    drop_user_scratch = f"DROP TABLE IF EXISTS {USER_SCRATCH};"
    drop_doc_scratch = f"DROP TABLE IF EXISTS {DOC_SCRATCH};"

    # Leave the migration's single transaction so every batch commits alone.
    with op.get_context().autocommit_block():
        # Snapshot the target user ids once.
        op.execute(drop_user_scratch)
        op.execute(
            f"CREATE UNLOGGED TABLE {USER_SCRATCH} AS "
            'SELECT id FROM "user" WHERE last_login IS NULL;'
        )
        op.execute(f"ALTER TABLE {USER_SCRATCH} ADD PRIMARY KEY (id);")

        total_users = _row_count(USER_SCRATCH)
        if not total_users:
            logger.info("no users with NULL last_login; nothing to remove")
            op.execute(drop_user_scratch)
            return

        logger.info(
            "found %s users with NULL last_login (never logged back in); "
            "removing them and all data in search spaces they own",
            f"{total_users:,}",
        )

        # Snapshot the documents in search spaces owned by those users.
        # Deleting them explicitly, in batches, is what keeps the chunks
        # cascade bounded.
        op.execute(drop_doc_scratch)
        op.execute(
            f"""
            CREATE UNLOGGED TABLE {DOC_SCRATCH} AS
            SELECT d.id
            FROM documents d
            JOIN searchspaces s ON s.id = d.search_space_id
            WHERE s.user_id IN (SELECT id FROM {USER_SCRATCH});
            """
        )
        op.execute(f"ALTER TABLE {DOC_SCRATCH} ADD PRIMARY KEY (id);")
        total_docs = _row_count(DOC_SCRATCH)

        # Phase 1: documents (cascades chunks, document_versions,
        # document_files), one committed batch at a time.
        logger.info(
            "phase 1/2: deleting %s documents (cascades their chunks) "
            "in batches of %s...",
            f"{total_docs:,}",
            f"{DOC_BATCH:,}",
        )
        _batched_delete(
            bind,
            scratch=DOC_SCRATCH,
            target_table="documents",
            target_col="id",
            batch_size=DOC_BATCH,
            total=total_docs,
            label="documents",
        )
        op.execute(drop_doc_scratch)

        # Phase 2: the users themselves. This cascades the now-empty search
        # spaces plus all remaining user-/space-scoped rows.
        logger.info(
            "phase 2/2: deleting %s users (cascades search spaces and "
            "remaining data) in batches of %s...",
            f"{total_users:,}",
            f"{USER_BATCH:,}",
        )
        _batched_delete(
            bind,
            scratch=USER_SCRATCH,
            target_table='"user"',
            target_col="id",
            batch_size=USER_BATCH,
            total=total_users,
            label="users",
        )
        op.execute(drop_user_scratch)

        logger.info("migration 164 finished")
|
||||
|
||||
|
||||
def _batched_delete(
    bind: sa.engine.Connection,
    *,
    scratch: str,
    target_table: str,
    target_col: str,
    batch_size: int,
    total: int,
    label: str,
) -> None:
    """Drain ``scratch`` in batches, deleting the matching ``target_table`` rows.

    One statement per batch picks up to ``batch_size`` ids from ``scratch``,
    deletes the rows of ``target_table`` whose ``target_col`` matches, and
    removes those ids from ``scratch``. Because both deletes live in the same
    statement, an interrupted run leaves the scratch table in sync with what
    has actually been deleted. The loop ends when a batch pops no ids.

    ``total`` and ``label`` are used only for the progress log lines, which
    are throttled to one per ``LOG_EVERY_SECONDS`` (the final batch always
    logs).
    """
    started = time.monotonic()
    last_report = 0.0
    removed = 0

    pop_and_delete = sa.text(
        f"""
        WITH batch AS (
            SELECT id FROM {scratch} LIMIT :n
        ), deleted AS (
            DELETE FROM {target_table}
            WHERE {target_col} IN (SELECT id FROM batch)
        ), popped AS (
            DELETE FROM {scratch}
            WHERE id IN (SELECT id FROM batch)
            RETURNING id
        )
        SELECT count(*) FROM popped
        """
    )

    # A batch that pops zero ids means the scratch table is empty.
    while batch_count := (
        bind.execute(pop_and_delete, {"n": batch_size}).scalar() or 0
    ):
        removed += batch_count

        now = time.monotonic()
        if now - last_report < LOG_EVERY_SECONDS and removed < total:
            continue  # too soon since the last line, and not finished yet

        elapsed = now - started
        pct = 100.0 * removed / total if total else 100.0
        # Linear extrapolation from the fraction completed so far.
        eta = elapsed / pct * (100.0 - pct) if pct > 0 else 0.0
        logger.info(
            "%s deleted: %.1f%% (%s/%s) elapsed %s eta %s",
            label,
            pct,
            f"{removed:,}",
            f"{total:,}",
            _fmt_duration(elapsed),
            _fmt_duration(eta),
        )
        last_report = now

    logger.info(
        "deleted %s %s in %s",
        f"{removed:,}",
        label,
        _fmt_duration(time.monotonic() - started),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Log a warning and do nothing else.

    The upgrade deletes users and their cascaded data, which cannot be
    restored. Keeping this a no-op lets a downgrade chain pass through the
    revision.
    """
    notice = (
        "migration 164 (remove_inactive_users) is irreversible; "
        "downgrade is a no-op (deleted users/data are not restored)"
    )
    logger.warning(notice)
|
||||
183
surfsense_backend/alembic/versions/165_add_chunk_position.py
Normal file
183
surfsense_backend/alembic/versions/165_add_chunk_position.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
"""add chunks.position for explicit document order
|
||||
|
||||
Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no
|
||||
longer reflect document order. Backfill preserves the historical id ordering.
|
||||
|
||||
The backfill is done in committed batches (not one giant UPDATE) so that on a
|
||||
large table it: streams progress to the alembic console, keeps each transaction
|
||||
small, bounds WAL/bloat growth, and is resumable if interrupted.
|
||||
|
||||
Revision ID: 165
|
||||
Revises: 164
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "165"
|
||||
down_revision: str | None = "164"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Number of chunk ids processed per committed batch.
|
||||
BATCH_SIZE = 100_000
|
||||
# Minimum seconds between progress log lines (keeps the console readable).
|
||||
LOG_EVERY_SECONDS = 5.0
|
||||
SCRATCH_TABLE = "_chunk_position_backfill"
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def _fmt_duration(seconds: float) -> str:
    """Format a duration in seconds as a compact ``1h02m05s`` style string.

    Fractional seconds are truncated. Hours and minutes are omitted while they
    are zero, so short durations render as ``"42s"`` or ``"3m07s"``.

    Negative inputs are clamped to zero: ``divmod`` floors toward negative
    infinity, so ``-5`` would otherwise render as ``"-1h59m55s"``.
    """
    whole = max(int(seconds), 0)
    h, rem = divmod(whole, 3600)
    m, s = divmod(rem, 60)
    if h:
        return f"{h}h{m:02d}m{s:02d}s"
    if m:
        return f"{m}m{s:02d}s"
    return f"{s}s"
|
||||
|
||||
|
||||
def _index_exists(bind: sa.engine.Connection, name: str) -> bool:
    """Return True when an index called ``name`` is visible on the search path.

    The lookup is restricted with ``pg_table_is_visible`` so that an index of
    the same name in an unrelated schema of the same database is not mistaken
    for this migration's index. That matters because the caller treats the
    index's presence as an "already applied" marker and skips the backfill.

    Args:
        bind: Connection the catalog query is executed on.
        name: Unqualified index name to look for.
    """
    return bool(
        bind.execute(
            sa.text(
                "SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_class c "
                "WHERE c.relkind = 'i' AND c.relname = :n "
                "AND pg_catalog.pg_table_is_visible(c.oid))"
            ),
            {"n": name},
        ).scalar()
    )
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``chunks.position``, backfill it per document, then index it.

    Steps:
      1. Add the column (``NOT NULL DEFAULT 0``).
      2. Return early when both position indexes already exist; they are only
         created after a complete backfill, so their presence marks the
         migration as already applied.
      3. Inside an autocommit block, compute each chunk's zero-based rank
         within its document (ordered by id) into an UNLOGGED scratch table
         and copy it onto ``chunks`` one id-range batch at a time.
      4. Create the two indexes.
    """
    bind = op.get_bind()

    # Adding a NOT NULL column with a constant default is metadata-only on
    # PostgreSQL 11+, so this is fast even on very large tables.
    op.execute(
        "ALTER TABLE chunks ADD COLUMN IF NOT EXISTS position INTEGER NOT NULL DEFAULT 0;"
    )

    # Idempotent fast path: re-running a completed migration is a cheap no-op.
    already_applied = _index_exists(bind, "ix_chunks_position") and _index_exists(
        bind, "ix_chunks_document_id_position"
    )
    if already_applied:
        logger.info("migration 165 already applied; skipping backfill")
        return

    # Leave the migration's single transaction so every batch commits alone.
    with op.get_context().autocommit_block():
        # reltuples is a planner estimate and is -1 on never-analyzed tables;
        # it only feeds the log line below, so <= 0 is shown as "unknown".
        estimated_rows = (
            bind.execute(
                sa.text(
                    "SELECT reltuples::bigint FROM pg_class WHERE relname = 'chunks'"
                )
            ).scalar()
            or 0
        )
        if estimated_rows > 0:
            total_rows_display = f"~{estimated_rows:,}"
        else:
            total_rows_display = "an unknown number of"

        min_id, max_id = bind.execute(
            sa.text("SELECT min(id), max(id) FROM chunks")
        ).one()

        if min_id is not None:
            # Precompute per-document ordering once into an UNLOGGED scratch
            # table (low WAL). ROW_NUMBER must see each whole document, so it
            # cannot be computed per id-range slice.
            logger.info(
                "building position mapping for %s chunks (this is a single "
                "scan; the batched UPDATE below reports progress)...",
                total_rows_display,
            )
            op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};")
            op.execute(
                f"""
                CREATE UNLOGGED TABLE {SCRATCH_TABLE} AS
                SELECT id,
                       (ROW_NUMBER() OVER (PARTITION BY document_id ORDER BY id) - 1)::int AS rn
                FROM chunks;
                """
            )
            op.execute(f"ALTER TABLE {SCRATCH_TABLE} ADD PRIMARY KEY (id);")

            # The statement text is the same for every batch; only the bound
            # id range changes. Rows already holding the right position are
            # skipped, so a resumed run rewrites only what is still missing.
            batch_update = sa.text(
                f"""
                UPDATE chunks c
                SET position = m.rn
                FROM {SCRATCH_TABLE} m
                WHERE c.id = m.id
                  AND c.id >= :lo
                  AND c.id < :hi
                  AND c.position IS DISTINCT FROM m.rn
                """
            )

            id_span = max(max_id - min_id + 1, 1)
            started = time.monotonic()
            last_log = 0.0
            rewritten = 0

            for lo in range(min_id, max_id + 1, BATCH_SIZE):
                hi = lo + BATCH_SIZE  # exclusive upper bound
                outcome = bind.execute(batch_update, {"lo": lo, "hi": hi})
                rewritten += outcome.rowcount or 0

                now = time.monotonic()
                reached = min(hi, max_id + 1)
                pct = min(100.0, 100.0 * (reached - min_id) / id_span)
                final_batch = hi > max_id
                if final_batch or now - last_log >= LOG_EVERY_SECONDS:
                    elapsed = now - started
                    eta = elapsed / pct * (100.0 - pct) if pct > 0 else 0.0
                    logger.info(
                        "backfill position: %.1f%% (id<%s, %s rows rewritten) "
                        "elapsed %s eta %s",
                        pct,
                        f"{reached:,}",
                        f"{rewritten:,}",
                        _fmt_duration(elapsed),
                        _fmt_duration(eta),
                    )
                    last_log = now

            logger.info(
                "backfill complete: %s rows rewritten in %s",
                f"{rewritten:,}",
                _fmt_duration(time.monotonic() - started),
            )
            op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};")
        else:
            logger.info("chunks table is empty; nothing to backfill")

        logger.info("creating index ix_chunks_position...")
        op.execute("CREATE INDEX IF NOT EXISTS ix_chunks_position ON chunks(position);")
        logger.info("creating index ix_chunks_document_id_position...")
        op.execute(
            "CREATE INDEX IF NOT EXISTS ix_chunks_document_id_position "
            "ON chunks(document_id, position);"
        )
        logger.info("migration 165 finished")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the backfill scratch table, both position indexes, and the column.

    Each statement carries ``IF EXISTS``, so a partially applied upgrade can
    still be rolled back.
    """
    teardown_statements = (
        f"DROP TABLE IF EXISTS {SCRATCH_TABLE};",
        "DROP INDEX IF EXISTS ix_chunks_document_id_position;",
        "DROP INDEX IF EXISTS ix_chunks_position;",
        "ALTER TABLE chunks DROP COLUMN IF EXISTS position;",
    )
    for statement in teardown_statements:
        op.execute(statement)
|
||||
|
|
@ -241,8 +241,15 @@ async def _create_document(
|
|||
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||
session.add_all(
|
||||
[
|
||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||
Chunk(
|
||||
document_id=doc.id,
|
||||
content=text,
|
||||
embedding=embedding,
|
||||
position=i,
|
||||
)
|
||||
for i, (text, embedding) in enumerate(
|
||||
zip(chunks, chunk_embeddings, strict=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
return doc
|
||||
|
|
@ -289,8 +296,15 @@ async def _update_document(
|
|||
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||
session.add_all(
|
||||
[
|
||||
Chunk(document_id=document.id, content=text, embedding=embedding)
|
||||
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||
Chunk(
|
||||
document_id=document.id,
|
||||
content=text,
|
||||
embedding=embedding,
|
||||
position=i,
|
||||
)
|
||||
for i, (text, embedding) in enumerate(
|
||||
zip(chunks, chunk_embeddings, strict=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
return document
|
||||
|
|
@ -475,7 +489,9 @@ async def _load_chunks_for_snapshot(
|
|||
session: AsyncSession, *, doc_id: int
|
||||
) -> list[dict[str, str]]:
|
||||
rows = await session.execute(
|
||||
select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id)
|
||||
select(Chunk.content)
|
||||
.where(Chunk.document_id == doc_id)
|
||||
.order_by(Chunk.position, Chunk.id)
|
||||
)
|
||||
return [{"content": row.content} for row in rows.all() if row.content is not None]
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def build_agent_with_cache(
|
|||
mcp_tools_by_agent: dict[str, list[BaseTool]],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
) -> Any:
|
||||
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||
|
||||
|
|
@ -121,7 +121,7 @@ async def build_agent_with_cache(
|
|||
# Bound into the generate_image subagent tool at construction time, so it
|
||||
# must key the compiled-agent cache to avoid leaking one automation's
|
||||
# image model into another with the same config_id/search_space.
|
||||
image_generation_config_id_override,
|
||||
image_gen_model_id_override,
|
||||
)
|
||||
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
|
||||
|
||||
``image_generation_config_id`` overrides the search space's image model for
|
||||
``image_gen_model_id`` overrides the search space's image model for
|
||||
this invocation (used by automations to run on their captured model). When
|
||||
``None``, the ``generate_image`` tool resolves the live search-space pref.
|
||||
"""
|
||||
|
|
@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
"llm": llm,
|
||||
# Per-invocation image model override (automations run on their captured
|
||||
# model). Reaches the generate_image subagent tool via subagent_dependencies.
|
||||
"image_generation_config_id_override": image_generation_config_id,
|
||||
"image_gen_model_id_override": image_gen_model_id,
|
||||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
|
|
@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
image_generation_config_id_override=image_generation_config_id,
|
||||
image_gen_model_id_override=image_gen_model_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -508,7 +508,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
chunk_rows = await session.execute(
|
||||
select(Chunk.id, Chunk.content)
|
||||
.where(Chunk.document_id == document.id)
|
||||
.order_by(Chunk.id)
|
||||
.order_by(Chunk.position, Chunk.id)
|
||||
)
|
||||
chunks = [
|
||||
{"chunk_id": row.id, "content": row.content} for row in chunk_rows.all()
|
||||
|
|
@ -725,7 +725,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
.join(Document, Document.id == Chunk.document_id)
|
||||
.where(Document.search_space_id == self.search_space_id)
|
||||
.where(Chunk.content.ilike(f"%{pattern}%"))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
)
|
||||
chunk_rows = await session.execute(sub)
|
||||
per_doc: dict[int, int] = {}
|
||||
|
|
|
|||
|
|
@ -394,7 +394,10 @@ async def browse_recent_documents(
|
|||
Chunk.document_id,
|
||||
Chunk.content,
|
||||
func.row_number()
|
||||
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||
.over(
|
||||
partition_by=Chunk.document_id,
|
||||
order_by=(Chunk.position, Chunk.id),
|
||||
)
|
||||
.label("rn"),
|
||||
)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
|
|
@ -404,7 +407,7 @@ async def browse_recent_documents(
|
|||
chunk_query = (
|
||||
select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content)
|
||||
.where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
|
||||
.order_by(numbered.c.document_id, numbered.c.chunk_id)
|
||||
.order_by(numbered.c.document_id, numbered.c.rn)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
fetched_chunks = chunk_result.all()
|
||||
|
|
@ -531,7 +534,7 @@ async def fetch_mentioned_documents(
|
|||
chunk_result = await session.execute(
|
||||
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||
.where(Chunk.document_id.in_(list(docs.keys())))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
)
|
||||
chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs}
|
||||
for row in chunk_result.all():
|
||||
|
|
|
|||
|
|
@ -10,70 +10,53 @@ from langgraph.types import Command
|
|||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
    """Return the first entry of ``config.GLOBAL_MODELS`` whose ``"id"`` equals
    ``model_id``, or ``None`` when no entry matches."""
    for candidate in config.GLOBAL_MODELS:
        if candidate.get("id") == model_id:
            return candidate
    return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
    """Pick the provider prefix used when building a model string.

    A non-empty ``custom_provider`` wins outright. Otherwise the upper-cased
    ``provider`` is looked up in ``_PROVIDER_MAP``, falling back to the
    lower-cased ``provider`` when it has no mapping.
    """
    return custom_provider or _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Return ``"<prefix>/<model_name>"``, where the prefix comes from
    ``_resolve_provider_prefix(provider, custom_provider)``."""
    prefix = _resolve_provider_prefix(provider, custom_provider)
    return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Return the first entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS`` whose
    ``"id"`` equals ``config_id``, or ``None`` when no entry matches."""
    return next(
        (
            entry
            for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
            if entry.get("id") == config_id
        ),
        None,
    )
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
    """Return the first entry of ``config.GLOBAL_CONNECTIONS`` whose ``"id"``
    equals ``connection_id``, or ``None`` when no entry matches."""
    for candidate in config.GLOBAL_CONNECTIONS:
        if candidate.get("id") == connection_id:
            return candidate
    return None
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
):
|
||||
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
|
||||
|
||||
``image_generation_config_id_override``: when set (automations running on a
|
||||
captured model), use this config id instead of reading the search space's
|
||||
live ``image_generation_config_id``.
|
||||
``image_gen_model_id_override``: when set (automations running on a
|
||||
captured model), use this model id instead of reading the search space's
|
||||
live ``image_gen_model_id``.
|
||||
"""
|
||||
del db_session # tool uses a fresh per-call session instead
|
||||
|
||||
|
|
@ -118,26 +101,23 @@ def create_generate_image_tool(
|
|||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
if image_generation_config_id_override is not None:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
if image_gen_model_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
config_id = (
|
||||
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# size/quality/style are intentionally omitted: valid values
|
||||
|
|
@ -147,73 +127,86 @@ def create_generate_image_tool(
|
|||
gen_kwargs["n"] = n
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
err = (
|
||||
"No image generation models configured. "
|
||||
"No image generation models available. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
)
|
||||
return _failed({"error": err}, error=err)
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
|
||||
provider_base_url: str | None = None
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not has_capability(
|
||||
global_model, "image_gen"
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
global_model["connection_id"]
|
||||
)
|
||||
if not global_connection:
|
||||
err = f"Image generation connection for model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
provider_base_url = resolved_kwargs.get("api_base")
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
# Positive ID = Model + Connection
|
||||
cfg_result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
db_model = cfg_result.scalars().first()
|
||||
if (
|
||||
not db_model
|
||||
or not db_model.connection
|
||||
or not db_model.connection.enabled
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
conn = db_model.connection
|
||||
if (
|
||||
conn.search_space_id is not None
|
||||
and conn.search_space_id != search_space_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and conn.user_id != search_space.user_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
provider_base_url = resolved_kwargs.get("api_base")
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
|
|
@ -230,7 +223,7 @@ def create_generate_image_tool(
|
|||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_generation_config_id=config_id,
|
||||
image_gen_model_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
|
|
@ -252,8 +245,19 @@ def create_generate_image_tool(
|
|||
|
||||
# b64_json (e.g. gpt-image-1) is served via our backend endpoint so
|
||||
# megabytes of base64 don't bloat the LLM context.
|
||||
# Some OpenAI-compatible backends (e.g. Xinference) return a relative
|
||||
# URL like /files/image.png. Browsers can't resolve these, so we
|
||||
# prepend the provider's base origin when the URL starts with "/".
|
||||
if first_image.get("url"):
|
||||
image_url = first_image["url"]
|
||||
raw_url: str = first_image["url"]
|
||||
if raw_url.startswith("/") and provider_base_url:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(provider_base_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
image_url = f"{origin}{raw_url}"
|
||||
else:
|
||||
image_url = raw_url
|
||||
elif first_image.get("b64_json"):
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
|
|
|
|||
|
|
@ -51,8 +51,6 @@ def load_tools(
|
|||
create_generate_image_tool(
|
||||
search_space_id=d["search_space_id"],
|
||||
db_session=d["db_session"],
|
||||
image_generation_config_id_override=d.get(
|
||||
"image_generation_config_id_override"
|
||||
),
|
||||
image_gen_model_id_override=d.get("image_gen_model_id_override"),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ async def _browse_recent_documents(
|
|||
chunk_query = (
|
||||
select(Chunk)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
raw_chunks = chunk_result.scalars().all()
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
3. Database model-connections table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
|
|
@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage
|
|||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
|
|
@ -33,10 +31,7 @@ from app.agents.chat.runtime.prompt_caching import (
|
|||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -51,16 +46,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
|||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
sanitized: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
next_msg = msg.model_copy(deep=True)
|
||||
if isinstance(next_msg.content, list):
|
||||
next_msg.content = _sanitize_content(next_msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
isinstance(next_msg, AIMessage)
|
||||
and (not next_msg.content or next_msg.content == "")
|
||||
and getattr(next_msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
next_msg.content = None # type: ignore[assignment]
|
||||
sanitized.append(next_msg)
|
||||
return sanitized
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
|
|
@ -91,13 +89,21 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
|
||||
# in provider_capabilities so the YAML loader can resolve prefixes during
|
||||
# app.config init without importing the agent/tools tree.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return await super()._agenerate(
|
||||
_sanitize_messages(messages),
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
|
|
@ -121,8 +127,9 @@ class AgentConfig:
|
|||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
This combines resolved model settings with prompt configuration.
|
||||
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
|
||||
a concrete global or BYOK model before constructing ChatLiteLLM.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
|
|
@ -170,7 +177,7 @@ class AgentConfig:
|
|||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
config_name="Auto",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
|
|
@ -181,64 +188,21 @@ class AgentConfig:
|
|||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a NewLLMConfig database model."""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
Supports the same prompt fields as NewLLMConfig (system_instructions,
|
||||
use_default_system_instructions, citations_enabled).
|
||||
Supports prompt fields such as system_instructions,
|
||||
use_default_system_instructions, and citations_enabled.
|
||||
"""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
provider = yaml_config.get("provider") or yaml_config.get(
|
||||
"litellm_provider", ""
|
||||
)
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -324,93 +288,15 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
|||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load a NewLLMConfig from the database by ID."""
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load the agent LLM config for a search space via its agent_llm_id.
|
||||
|
||||
Positive id -> DB; negative -> YAML; None -> first global config (-1).
|
||||
"""
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# In-memory covers static YAML + dynamic OpenRouter configs.
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
|
||||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
provider = llm_config.get("provider") or llm_config.get(
|
||||
"litellm_provider", "openai"
|
||||
)
|
||||
model_string = f"{provider}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
@ -433,29 +319,17 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
|
||||
"""Create a ChatLiteLLM from an already resolved concrete model config."""
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal injection points only: auto-mode fans out across
|
||||
# providers, so provider-specific kwargs have no known target.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
print(
|
||||
"Error: Auto mode must be resolved to a concrete model before LLM creation"
|
||||
)
|
||||
return None
|
||||
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
model_string = f"{agent_config.provider}/{agent_config.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from app.config import (
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
|
|
@ -622,7 +621,6 @@ async def lifespan(app: FastAPI):
|
|||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||
|
|
|
|||
|
|
@ -39,31 +39,31 @@ async def build_dependencies(
|
|||
*,
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
agent_llm_id: int | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
vision_llm_config_id: int | None = None,
|
||||
chat_model_id: int | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
vision_model_id: int | None = None,
|
||||
) -> AgentDependencies:
|
||||
"""Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer.
|
||||
|
||||
Resolves the agent LLM from the automation's *captured* model snapshot
|
||||
(``agent_llm_id``) so runs are insulated from later chat/search-space model
|
||||
Resolves the chat model from the automation's *captured* model snapshot
|
||||
(``chat_model_id``) so runs are insulated from later chat/search-space model
|
||||
changes. The model policy is enforced here as a runtime backstop: a captured
|
||||
model that is no longer billable (e.g. a premium global config was removed)
|
||||
fails the run clearly instead of silently consuming a free model.
|
||||
|
||||
When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``agent_llm_id`` and validate that.
|
||||
When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``chat_model_id`` and validate that.
|
||||
"""
|
||||
if agent_llm_id is not None:
|
||||
if chat_model_id is not None:
|
||||
try:
|
||||
assert_models_billable(
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_agent_llm_id = agent_llm_id or 0
|
||||
resolved_chat_model_id = chat_model_id or 0
|
||||
else:
|
||||
search_space = await session.get(SearchSpace, search_space_id)
|
||||
if search_space is None:
|
||||
|
|
@ -72,15 +72,15 @@ async def build_dependencies(
|
|||
assert_automation_models_billable(search_space)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_agent_llm_id = search_space.agent_llm_id or 0
|
||||
resolved_chat_model_id = search_space.chat_model_id or 0
|
||||
|
||||
llm, agent_config, err = await load_llm_bundle(
|
||||
session,
|
||||
config_id=resolved_agent_llm_id,
|
||||
config_id=resolved_chat_model_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if err is not None or llm is None:
|
||||
raise DependencyError(err or "failed to load agent LLM config")
|
||||
raise DependencyError(err or "failed to load chat model config")
|
||||
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
|
|
|
|||
|
|
@ -150,9 +150,9 @@ async def run_agent_task(
|
|||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
agent_llm_id=ctx.agent_llm_id,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
vision_llm_config_id=ctx.vision_llm_config_id,
|
||||
chat_model_id=ctx.chat_model_id,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
vision_model_id=ctx.vision_model_id,
|
||||
)
|
||||
|
||||
agent = await create_multi_agent_chat_deep_agent(
|
||||
|
|
@ -167,7 +167,7 @@ async def run_agent_task(
|
|||
firecrawl_api_key=deps.firecrawl_api_key,
|
||||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
)
|
||||
|
||||
agent_query, runtime_context = await _resolve_mention_context(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class ActionContext:
|
|||
# Captured model snapshot from the automation definition (``definition.models``),
|
||||
# resolved per run instead of the live search space. ``None`` falls back to the
|
||||
# search space's current prefs (defensive; should not happen post-capture).
|
||||
agent_llm_id: int | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
vision_llm_config_id: int | None = None
|
||||
chat_model_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
vision_model_id: int | None = None
|
||||
|
||||
|
||||
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
||||
|
|
|
|||
|
|
@ -132,9 +132,7 @@ def _build_action_ctx(
|
|||
step_id=step.step_id,
|
||||
search_space_id=automation.search_space_id,
|
||||
creator_user_id=automation.created_by_user_id,
|
||||
agent_llm_id=models.agent_llm_id if models else None,
|
||||
image_generation_config_id=(
|
||||
models.image_generation_config_id if models else None
|
||||
),
|
||||
vision_llm_config_id=models.vision_llm_config_id if models else None,
|
||||
chat_model_id=models.chat_model_id if models else None,
|
||||
image_gen_model_id=models.image_gen_model_id if models else None,
|
||||
vision_model_id=models.vision_model_id if models else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec
|
|||
class AutomationModels(BaseModel):
|
||||
"""Captured model profile for an automation.
|
||||
|
||||
Snapshotted from the search space's preferences at create time so runs are
|
||||
insulated from later chat/search-space model changes. Config-id conventions
|
||||
Snapshotted from the search space's model roles at create time so runs are
|
||||
insulated from later chat/search-space model changes. Model-id conventions
|
||||
match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_llm_id: int = 0
|
||||
image_generation_config_id: int = 0
|
||||
vision_llm_config_id: int = 0
|
||||
chat_model_id: int = 0
|
||||
image_gen_model_id: int = 0
|
||||
vision_model_id: int = 0
|
||||
|
||||
|
||||
class AutomationDefinition(BaseModel):
|
||||
|
|
|
|||
|
|
@ -57,9 +57,9 @@ class AutomationService:
|
|||
else:
|
||||
search_space = await self._assert_models_billable(payload.search_space_id)
|
||||
payload.definition.models = AutomationModels(
|
||||
agent_llm_id=search_space.agent_llm_id or 0,
|
||||
image_generation_config_id=search_space.image_generation_config_id or 0,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id or 0,
|
||||
chat_model_id=search_space.chat_model_id or 0,
|
||||
image_gen_model_id=search_space.image_gen_model_id or 0,
|
||||
vision_model_id=search_space.vision_model_id or 0,
|
||||
)
|
||||
|
||||
automation = Automation(
|
||||
|
|
@ -225,9 +225,9 @@ class AutomationService:
|
|||
"""
|
||||
try:
|
||||
assert_models_billable(
|
||||
agent_llm_id=models.agent_llm_id,
|
||||
image_generation_config_id=models.image_generation_config_id,
|
||||
vision_llm_config_id=models.vision_llm_config_id,
|
||||
chat_model_id=models.chat_model_id,
|
||||
image_gen_model_id=models.image_gen_model_id,
|
||||
vision_model_id=models.vision_model_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
Automations run unattended, so every run must be **billable**: it may only use
|
||||
either a premium global model (``billing_tier == "premium"``) or a user-provided
|
||||
BYOK model (a positive config id pointing at a per-user/per-space DB row). Free
|
||||
BYOK model (a positive model id pointing at a per-user/per-space DB row). Free
|
||||
global models and Auto mode are blocked, because Auto can dispatch to a free
|
||||
deployment and free models aren't metered in premium credits.
|
||||
|
||||
Config id conventions (shared across chat / image / vision):
|
||||
Model id conventions (shared across chat / image / vision):
|
||||
- ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` /
|
||||
``VISION_AUTO_MODE_ID``). Blocked.
|
||||
- ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium.
|
||||
|
|
@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal
|
|||
if TYPE_CHECKING:
|
||||
from app.db import SearchSpace
|
||||
|
||||
ModelKind = Literal["llm", "image", "vision"]
|
||||
ModelKind = Literal["chat", "image", "vision"]
|
||||
|
||||
_KIND_LABEL: dict[ModelKind, str] = {
|
||||
"llm": "agent LLM",
|
||||
"chat": "chat model",
|
||||
"image": "image generation model",
|
||||
"vision": "vision model",
|
||||
}
|
||||
|
||||
|
||||
def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
|
||||
"""Return True if a negative (global) config id is a premium tier model."""
|
||||
def _is_premium_global(model_id: int) -> bool:
|
||||
"""Return True if a negative (global) model id is a premium tier model."""
|
||||
from app.config import config as app_config
|
||||
|
||||
cfg: dict | None = None
|
||||
if kind == "llm":
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
|
||||
cfg = load_global_llm_config_by_id(config_id)
|
||||
elif kind == "image":
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
else: # vision
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not cfg:
|
||||
model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
if not model:
|
||||
return False
|
||||
return str(cfg.get("billing_tier", "free")).lower() == "premium"
|
||||
return str(model.get("billing_tier", "free")).lower() == "premium"
|
||||
|
||||
|
||||
def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved config id as allowed or blocked.
|
||||
def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved model id as allowed or blocked.
|
||||
|
||||
Returns ``(allowed, reason)``; ``reason`` is empty when allowed.
|
||||
"""
|
||||
label = _KIND_LABEL[kind]
|
||||
|
||||
if config_id is None or config_id == 0:
|
||||
if model_id is None or model_id == 0:
|
||||
return (
|
||||
False,
|
||||
f"The {label} is set to Auto mode. Automations require an explicit "
|
||||
"premium model or your own (BYOK) model so every run is billable.",
|
||||
)
|
||||
|
||||
if config_id > 0:
|
||||
# Positive id → user-owned BYOK config. Always allowed.
|
||||
if model_id > 0:
|
||||
# Positive id -> user/search-space BYOK model. Always allowed.
|
||||
return True, ""
|
||||
|
||||
# Negative id → global config. Allowed only if premium.
|
||||
if _is_premium_global(kind, config_id):
|
||||
# Negative id -> global model. Allowed only if premium.
|
||||
if _is_premium_global(model_id):
|
||||
return True, ""
|
||||
|
||||
return (
|
||||
|
|
@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
|
|||
|
||||
def get_model_eligibility(
|
||||
*,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
) -> dict:
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids.
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids.
|
||||
|
||||
The ID-based core shared by both the search-space path (creation/eligibility)
|
||||
and the captured-snapshot path (runtime backstop). Each violation is
|
||||
``{"kind", "config_id", "reason"}``.
|
||||
``{"kind", "model_id", "reason"}``.
|
||||
"""
|
||||
checks: list[tuple[ModelKind, int | None]] = [
|
||||
("llm", agent_llm_id),
|
||||
("image", image_generation_config_id),
|
||||
("vision", vision_llm_config_id),
|
||||
("chat", chat_model_id),
|
||||
("image", image_gen_model_id),
|
||||
("vision", vision_model_id),
|
||||
]
|
||||
|
||||
violations: list[dict] = []
|
||||
for kind, config_id in checks:
|
||||
allowed, reason = _classify(kind, config_id)
|
||||
for kind, model_id in checks:
|
||||
allowed, reason = _classify(kind, model_id)
|
||||
if not allowed:
|
||||
violations.append({"kind": kind, "config_id": config_id, "reason": reason})
|
||||
violations.append({"kind": kind, "model_id": model_id, "reason": reason})
|
||||
|
||||
return {"allowed": not violations, "violations": violations}
|
||||
|
||||
|
|
@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict:
|
|||
wrapper over :func:`get_model_eligibility`.
|
||||
"""
|
||||
return get_model_eligibility(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
image_gen_model_id=search_space.image_gen_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception):
|
|||
|
||||
def assert_models_billable(
|
||||
*,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
) -> None:
|
||||
"""Raise :class:`AutomationModelPolicyError` if any explicit id is not billable.
|
||||
|
||||
|
|
@ -160,9 +135,9 @@ def assert_models_billable(
|
|||
captured model snapshot.
|
||||
"""
|
||||
result = get_model_eligibility(
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
)
|
||||
if not result["allowed"]:
|
||||
raise AutomationModelPolicyError(result["violations"])
|
||||
|
|
|
|||
|
|
@ -115,14 +115,12 @@ def init_worker(**kwargs):
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Celery configuration, sourced from the central Config singleton
|
||||
|
|
@ -192,6 +190,8 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.tasks.celery_tasks.auto_reload_task",
|
||||
"app.tasks.celery_tasks.gateway_tasks",
|
||||
"app.etl_pipeline.cache.eviction.task",
|
||||
"app.indexing_pipeline.cache.eviction.task",
|
||||
"app.automations.tasks.execute_run",
|
||||
"app.automations.triggers.builtin.schedule.selector",
|
||||
"app.automations.triggers.builtin.event.selector",
|
||||
|
|
@ -306,6 +306,18 @@ celery_app.conf.beat_schedule = {
|
|||
"schedule": crontab(hour="3", minute="17"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Prune the ETL parse cache (TTL + size budget) once daily, off-peak.
|
||||
"evict-etl-cache": {
|
||||
"task": "evict_etl_cache",
|
||||
"schedule": crontab(hour="4", minute="0"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Prune the embedding cache (chunk+embedding sets) once daily, off-peak.
|
||||
"evict-embedding-cache": {
|
||||
"task": "evict_embedding_cache",
|
||||
"schedule": crontab(hour="4", minute="30"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Fire due automation schedule triggers (Beat entry owned by the schedule
|
||||
# trigger; see app.automations.triggers.builtin.schedule.source).
|
||||
**SCHEDULE_BEAT_SCHEDULE,
|
||||
|
|
|
|||
|
|
@ -78,8 +78,7 @@ def load_global_llm_configs():
|
|||
# stamps) never leak into the cached YAML structure.
|
||||
configs = copy.deepcopy(data.get("global_llm_configs", []))
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
|
|
@ -104,7 +103,7 @@ def load_global_llm_configs():
|
|||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -120,10 +119,10 @@ def load_global_llm_configs():
|
|||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Stamp Auto ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose provider == "OPENROUTER" via _enrich_health.
|
||||
# whose provider == "openrouter" via _enrich_health.
|
||||
try:
|
||||
from app.services.quality_score import static_score_yaml
|
||||
|
||||
|
|
@ -133,7 +132,7 @@ def load_global_llm_configs():
|
|||
cfg["quality_score_static"] = static_q
|
||||
cfg["quality_score"] = static_q
|
||||
cfg["quality_score_health"] = None
|
||||
# YAML cfgs whose provider is OPENROUTER are also subject
|
||||
# YAML cfgs whose provider is openrouter are also subject
|
||||
# to health gating against their own /endpoints data — a
|
||||
# hand-picked dead OR model is still dead. _enrich_health
|
||||
# re-stamps health_gated for them on the next refresh tick.
|
||||
|
|
@ -211,42 +210,6 @@ def load_global_image_gen_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
    """Return a deep copy of the ``global_vision_llm_configs`` list.

    Every dict entry that lacks a ``billing_tier`` key is stamped with
    ``"free"``. An empty list is returned when no global config data is
    available, and also (after printing a warning) when reading or
    copying the entries raises.
    """
    yaml_data = _global_config_data()
    if yaml_data:
        try:
            raw_entries = yaml_data.get("global_vision_llm_configs", [])
            # Work on a copy so the tier stamp below never mutates the
            # structure handed back by _global_config_data().
            vision_entries = copy.deepcopy(raw_entries or [])
            for entry in vision_entries:
                if not isinstance(entry, dict):
                    continue
                entry.setdefault("billing_tier", "free")
            return vision_entries
        except Exception as e:
            print(f"Warning: Failed to load global vision LLM configs: {e}")
    return []
|
||||
|
||||
|
||||
def load_vision_llm_router_settings():
    """Return the vision LLM router settings merged over built-in defaults.

    Values under ``vision_llm_router_settings`` in the global config data
    override the defaults key by key. The defaults alone are returned when
    no global config data is available, and also (after printing a
    warning) when the merge raises.
    """
    fallback = {
        "routing_strategy": "usage-based-routing",
        "num_retries": 3,
        "allowed_fails": 3,
        "cooldown_time": 60,
    }

    yaml_data = _global_config_data()
    if yaml_data:
        try:
            overrides = yaml_data.get("vision_llm_router_settings", {})
            # Later keys win, so configured values replace the defaults.
            merged = {**fallback, **overrides}
            return merged
        except Exception as e:
            print(f"Warning: Failed to load vision LLM router settings: {e}")
    return fallback
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
|
@ -363,8 +326,8 @@ def initialize_openrouter_integration():
|
|||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# Image generation emissions reuse the catalogue already cached by
|
||||
# ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
|
|
@ -378,21 +341,26 @@ def initialize_openrouter_integration():
|
|||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
refresh_global_model_catalog()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def materialize_global_configs():
    """Build the global model catalog from the configs held on ``config``.

    Reads ``GLOBAL_LLM_CONFIGS`` and ``GLOBAL_IMAGE_GEN_CONFIGS`` from
    ``config`` (each falling back to an empty list when the attribute is
    absent) and returns whatever ``materialize_global_model_catalog``
    produces for them.
    """
    # Imported lazily, inside the function, rather than at module import time.
    from app.services.global_model_catalog import (
        materialize_global_model_catalog as build_catalog,
    )

    chat_entries = getattr(config, "GLOBAL_LLM_CONFIGS", [])
    image_entries = getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [])
    return build_catalog(chat_configs=chat_entries, image_configs=image_entries)
|
||||
|
||||
|
||||
def refresh_global_model_catalog():
    """Recompute the global catalog and store both halves on ``config``.

    The pair returned by ``materialize_global_configs()`` is unpacked into
    ``config.GLOBAL_CONNECTIONS`` and ``config.GLOBAL_MODELS``.
    """
    config.GLOBAL_CONNECTIONS, config.GLOBAL_MODELS = materialize_global_configs()
|
||||
|
||||
|
||||
def initialize_pricing_registration():
|
||||
"""
|
||||
Teach LiteLLM the per-token cost of every deployment in
|
||||
|
|
@ -430,7 +398,10 @@ def initialize_llm_router():
|
|||
router_settings = config.ROUTER_SETTINGS
|
||||
|
||||
if not all_configs:
|
||||
print("Info: No global LLM configs found, Auto mode will not be available")
|
||||
print(
|
||||
"Info: No global LLM configs found; global Auto pool is unavailable. "
|
||||
"Auto can still use enabled BYOK models."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -475,32 +446,6 @@ def initialize_image_gen_router():
|
|||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
def initialize_vision_llm_router():
    """Initialize the Vision LLM router from the YAML-defined vision configs.

    The config list is re-read via ``load_global_vision_llm_configs()``,
    while the router settings are taken from
    ``config.VISION_LLM_ROUTER_SETTINGS``. When no configs are found an
    info message is printed and nothing is initialized. Any failure during
    initialization is reported as a printed warning and not re-raised.
    """
    yaml_vision_configs = load_global_vision_llm_configs()
    # Reuse the settings already stored on config; only the config list is
    # loaded again here.
    settings = config.VISION_LLM_ROUTER_SETTINGS

    if yaml_vision_configs:
        try:
            from app.services.vision_llm_router_service import VisionLLMRouterService

            VisionLLMRouterService.initialize(yaml_vision_configs, settings)
            model_count = len(yaml_vision_configs)
            strategy = settings.get("routing_strategy", "usage-based-routing")
            print(
                f"Info: Vision LLM Router initialized with {model_count} models "
                f"(strategy: {strategy})"
            )
        except Exception as e:
            print(f"Warning: Failed to initialize Vision LLM Router: {e}")
    else:
        print(
            "Info: No global vision LLM configs found, "
            "Vision LLM Auto mode will not be available"
        )
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -612,14 +557,15 @@ class Config:
|
|||
# Platform web search (SearXNG)
|
||||
SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST")
|
||||
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
SURFSENSE_PUBLIC_URL = os.getenv("SURFSENSE_PUBLIC_URL")
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") or SURFSENSE_PUBLIC_URL
|
||||
# Backend URL used to force https instead of http in the OAuth redirect URI
|
||||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
BACKEND_URL = os.getenv("BACKEND_URL") or SURFSENSE_PUBLIC_URL
|
||||
|
||||
# Messaging gateway (Telegram v1)
|
||||
# Messaging gateway
|
||||
# Global master switch: when FALSE, no gateway supervisors/workers start and all
|
||||
# gateway HTTP routes return 404, regardless of the per-channel flags below.
|
||||
GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "TRUE").upper() == "TRUE"
|
||||
# gated gateway HTTP routes return 404, regardless of the per-channel flags below.
|
||||
GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "FALSE").upper() == "TRUE"
|
||||
TELEGRAM_SHARED_BOT_TOKEN = os.getenv("TELEGRAM_SHARED_BOT_TOKEN")
|
||||
TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME")
|
||||
TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET")
|
||||
|
|
@ -784,7 +730,7 @@ class Config:
|
|||
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||
)
|
||||
|
||||
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||
# Per-podcast reservation (in micro-USD). One chat model call generating
|
||||
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||
# premium-model run. Tune via env.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||
|
|
@ -890,6 +836,13 @@ class Config:
|
|||
# LLM instances are now managed per-user through the LLMConfig system
|
||||
# Legacy environment variables removed in favor of user-specific configurations
|
||||
|
||||
# True when an operator-provided global_llm_config.yaml is present.
|
||||
# Used to gate the per-search-space LLM onboarding flow: when a global
|
||||
# config file exists, search spaces inherit it and onboarding is skipped.
|
||||
GLOBAL_LLM_CONFIG_FILE_EXISTS = (
|
||||
BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
).exists()
|
||||
|
||||
# Global LLM Configurations (optional)
|
||||
# Load from global_llm_config.yaml if available
|
||||
# These can be used as default options for users
|
||||
|
|
@ -904,11 +857,17 @@ class Config:
|
|||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Global Vision LLM Configurations (optional)
|
||||
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
|
||||
# Virtual GLOBAL connection/model catalog. This is server-only metadata
|
||||
# derived from global_llm_config.yaml; GLOBAL keys are not stored in DB.
|
||||
from app.services.global_model_catalog import (
|
||||
materialize_global_model_catalog as _materialize_global_model_catalog,
|
||||
)
|
||||
|
||||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
|
||||
chat_configs=GLOBAL_LLM_CONFIGS,
|
||||
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
|
||||
)
|
||||
del _materialize_global_model_catalog
|
||||
|
||||
# OpenRouter Integration settings (optional)
|
||||
OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings()
|
||||
|
|
@ -974,6 +933,47 @@ class Config:
|
|||
AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT")
|
||||
AZURE_DI_KEY = os.getenv("AZURE_DI_KEY")
|
||||
|
||||
# ETL parse cache: reuse parser output for identical bytes across workspaces.
|
||||
ETL_CACHE_ENABLED = (
|
||||
os.getenv("ETL_CACHE_ENABLED", "false").strip().lower() == "true"
|
||||
)
|
||||
# Bump to invalidate every cached entry after a parser/behaviour change.
|
||||
ETL_CACHE_PARSER_VERSION = int(os.getenv("ETL_CACHE_PARSER_VERSION", "1"))
|
||||
ETL_CACHE_TTL_DAYS = int(os.getenv("ETL_CACHE_TTL_DAYS", "90"))
|
||||
ETL_CACHE_MAX_TOTAL_MB = int(os.getenv("ETL_CACHE_MAX_TOTAL_MB", "5120"))
|
||||
ETL_CACHE_EVICTION_BATCH = int(os.getenv("ETL_CACHE_EVICTION_BATCH", "500"))
|
||||
# Optional dedicated blob storage; unset reuses the main file_storage backend.
|
||||
ETL_CACHE_STORAGE_BACKEND = os.getenv("ETL_CACHE_STORAGE_BACKEND")
|
||||
ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER")
|
||||
ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH")
|
||||
|
||||
# Embedding cache: reuse chunk+embedding output for identical markdown across
|
||||
# workspaces. Blobs share the ETL_CACHE_STORAGE_* backend.
|
||||
EMBEDDING_CACHE_ENABLED = (
|
||||
os.getenv("EMBEDDING_CACHE_ENABLED", "false").strip().lower() == "true"
|
||||
)
|
||||
# Bump to invalidate every cached embedding set after a chunker change.
|
||||
EMBEDDING_CACHE_CHUNKER_VERSION = int(
|
||||
os.getenv("EMBEDDING_CACHE_CHUNKER_VERSION", "1")
|
||||
)
|
||||
EMBEDDING_CACHE_TTL_DAYS = int(os.getenv("EMBEDDING_CACHE_TTL_DAYS", "90"))
|
||||
EMBEDDING_CACHE_MAX_TOTAL_MB = int(
|
||||
os.getenv("EMBEDDING_CACHE_MAX_TOTAL_MB", "5120")
|
||||
)
|
||||
EMBEDDING_CACHE_EVICTION_BATCH = int(
|
||||
os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500")
|
||||
)
|
||||
|
||||
# Incremental re-indexing: on document edits, keep chunk rows whose text is
|
||||
# unchanged (reusing their embeddings) and embed only new/changed chunks.
|
||||
# Kill switch -- disabling falls back to delete-all + full re-embed.
|
||||
CHUNK_RECONCILE_ENABLED = (
|
||||
os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true"
|
||||
)
|
||||
INDEXING_CHUNK_INSERT_BATCH_SIZE = int(
|
||||
os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200")
|
||||
)
|
||||
|
||||
# Proxy provider selection. Maps to a ProxyProvider implementation registered
|
||||
# in app/utils/proxy/registry.py. Add new vendors there and switch via this var.
|
||||
PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies")
|
||||
|
|
|
|||
|
|
@ -1,362 +1,236 @@
|
|||
# Global LLM Configuration
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
|
||||
# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist
|
||||
# 1. Copy this file to global_llm_config.yaml.
|
||||
# 2. Replace placeholder credentials, endpoints, deployment names, and pricing
|
||||
# with values from your own provider accounts.
|
||||
#
|
||||
# NOTE: The example API keys below are placeholders and won't work.
|
||||
# Replace them with your actual API keys to enable global configurations.
|
||||
# This file is intentionally safe to commit. Do not put real API keys in this
|
||||
# example file.
|
||||
#
|
||||
# These configurations will be available to all users as a convenient option
|
||||
# Users can choose to use these global configs or add their own
|
||||
# These YAML entries are materialized at startup as server-owned GLOBAL
|
||||
# connections and models:
|
||||
#
|
||||
# AUTO MODE (Recommended):
|
||||
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
|
||||
# - This helps avoid rate limits by distributing requests across multiple providers
|
||||
# - New users are automatically assigned Auto mode by default
|
||||
# - Configure router_settings below to customize the load balancing behavior
|
||||
# global_llm_configs -> GLOBAL chat models
|
||||
# global_image_generation_configs -> GLOBAL image generation models
|
||||
#
|
||||
# Structure matches NewLLMConfig:
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
# Do not add global_connections or global_models sections here. They are
|
||||
# runtime-derived metadata exposed through the model-connections APIs.
|
||||
#
|
||||
# Static config shape:
|
||||
# - Connection fields: provider, api_key, api_base, api_version
|
||||
# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params
|
||||
# - Public no-login SEO metadata: seo_title, seo_description
|
||||
# - Prompt defaults: system_instructions, use_default_system_instructions,
|
||||
# citations_enabled
|
||||
#
|
||||
# Provider notes:
|
||||
# - Use the canonical provider field.
|
||||
# - For Azure, use the bare deployment name in model_name, for example
|
||||
# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from
|
||||
# provider: "azure".
|
||||
#
|
||||
# GLOBAL ID namespace:
|
||||
# - ID 0 is reserved for Auto mode.
|
||||
# - Negative IDs are server-owned GLOBAL models.
|
||||
# - Positive IDs are user/BYOK database models.
|
||||
# - Keep static IDs unique across chat and image generation.
|
||||
# - Suggested static ranges: chat -1..-999, image -2001..-2999.
|
||||
# - Vision is not a separate config/table. Chat models that accept images use
|
||||
# supports_image_input: true.
|
||||
#
|
||||
# COST-BASED PREMIUM CREDITS:
|
||||
# Each premium config bills the user's USD-credit balance based on the
|
||||
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||
# per-token costs inline so they bill correctly:
|
||||
# Each premium model bills the user's USD-credit balance based on provider cost
|
||||
# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does
|
||||
# not know, declare per-token costs inline:
|
||||
#
|
||||
# litellm_params:
|
||||
# base_model: "my-custom-azure-deploy"
|
||||
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||
# input_cost_per_token: 0.000003
|
||||
# output_cost_per_token: 0.000015
|
||||
# base_model: "my-custom-deployment"
|
||||
# # USD per token; 0.00000125 == $1.25 per million input tokens.
|
||||
# input_cost_per_token: 0.00000125
|
||||
# output_cost_per_token: 0.00001
|
||||
#
|
||||
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||
# API — no inline declaration needed. Models without resolvable pricing
|
||||
# debit $0 from the user's balance and log a WARNING.
|
||||
# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's
|
||||
# API. Models without resolvable pricing debit $0 and log a warning.
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
# =============================================================================
|
||||
# Chat Auto Mode Router Settings
|
||||
# =============================================================================
|
||||
# These settings control how the LiteLLM Router distributes Auto-mode requests
|
||||
# across curated router-eligible GLOBAL chat deployments.
|
||||
router_settings:
|
||||
# Routing strategy options:
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
|
||||
# - "least-busy": Routes to least busy deployment
|
||||
# - "latency-based-routing": Routes based on response latency
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage.
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting.
|
||||
# - "least-busy": Routes to least busy deployment.
|
||||
# - "latency-based-routing": Routes based on response latency.
|
||||
routing_strategy: "usage-based-routing"
|
||||
|
||||
# Number of retries before failing
|
||||
num_retries: 3
|
||||
|
||||
# Number of failures allowed before cooling down a deployment
|
||||
allowed_fails: 3
|
||||
|
||||
# Cooldown time in seconds after allowed_fails is exceeded
|
||||
cooldown_time: 60
|
||||
# Optional fallback map:
|
||||
# fallbacks:
|
||||
# - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]}
|
||||
|
||||
# Fallback models (optional) - when primary fails, try these
|
||||
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||
# fallbacks: []
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Chat Models
|
||||
# =============================================================================
|
||||
global_llm_configs:
|
||||
# Example: OpenAI GPT-4 Turbo with citations enabled
|
||||
# Premium Azure chat model with image input support and explicit custom
|
||||
# pricing. This is the current shape to use for hosted GPT 5.x deployments.
|
||||
- id: -1
|
||||
name: "Global GPT-4 Turbo"
|
||||
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-4-turbo"
|
||||
name: "Azure GPT 5.1"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-1"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.1"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version is optional. Include it if your Azure deployment requires a
|
||||
# specific API version.
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 47500
|
||||
tpm: 14750000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Prompt Configuration
|
||||
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.1"
|
||||
input_cost_per_token: 0.00000125
|
||||
output_cost_per_token: 0.00001
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Anthropic Claude 3 Opus
|
||||
# Larger premium chat model. If your provider prices long-context traffic
|
||||
# differently, choose a conservative flat price or document the limitation
|
||||
# next to the inline pricing.
|
||||
- id: -2
|
||||
name: "Global Claude 3 Opus"
|
||||
description: "Anthropic's most capable model with citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "claude-3-opus"
|
||||
name: "Azure GPT 5.4"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-4"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 150000
|
||||
tpm: 15000000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4"
|
||||
input_cost_per_token: 0.0000025
|
||||
output_cost_per_token: 0.000015
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Fast model - GPT-3.5 Turbo (citations disabled for speed)
|
||||
# Free/no-login hosted model. Free models are visible to users when
|
||||
# anonymous_enabled/seo_enabled are true but do not debit premium credits.
|
||||
- id: -3
|
||||
name: "Global GPT-3.5 Turbo (Fast)"
|
||||
description: "Fast responses without citations for quick queries"
|
||||
name: "Azure GPT 5.4 Mini"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-3.5-turbo-fast"
|
||||
quota_reserve_tokens: 2000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false # Disabled for faster responses
|
||||
|
||||
# Example: Chinese LLM - DeepSeek with custom instructions
|
||||
- id: -4
|
||||
name: "Global DeepSeek Chat (Chinese)"
|
||||
description: "DeepSeek optimized for Chinese language responses"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "deepseek-chat-chinese"
|
||||
seo_slug: "gpt-5-4-mini-no-login"
|
||||
seo_title: "Free GPT 5.4 Mini Chat"
|
||||
seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in."
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "DEEPSEEK"
|
||||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Custom system instructions for Chinese responses
|
||||
system_instructions: |
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
|
||||
</system_instruction>
|
||||
use_default_system_instructions: false
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params
|
||||
# to enable accurate token counting, cost tracking, and max token limits
|
||||
- id: -5
|
||||
name: "Global Azure GPT-4o"
|
||||
description: "Azure OpenAI GPT-4o deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4o"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
# model_name format for Azure: azure/<your-deployment-name>
|
||||
model_name: "azure/gpt-4o-deployment"
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-mini"
|
||||
supports_image_input: false
|
||||
supports_tools: true
|
||||
max_input_tokens: 128000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
rpm: 1000
|
||||
tpm: 150000
|
||||
rpm: 15000
|
||||
tpm: 15000000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# REQUIRED for Azure: Specify the underlying OpenAI model
|
||||
# This fixes "Could not identify azure model" warnings
|
||||
# Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo
|
||||
base_model: "gpt-4o"
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4-mini"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4 Turbo
|
||||
- id: -6
|
||||
name: "Global Azure GPT-4 Turbo"
|
||||
description: "Azure OpenAI GPT-4 Turbo deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
model_name: "azure/gpt-4-turbo-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Groq - Fast inference
|
||||
- id: -7
|
||||
name: "Global Groq Llama 3"
|
||||
description: "Ultra-fast Llama 3 70B via Groq"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "groq-llama-3"
|
||||
quota_reserve_tokens: 8000
|
||||
provider: "GROQ"
|
||||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 8000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: MiniMax M3 - High-performance with 512K context window
|
||||
- id: -8
|
||||
name: "Global MiniMax M3"
|
||||
description: "MiniMax M3 with 512K context window and competitive pricing"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "minimax-m3"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "MINIMAX"
|
||||
model_name: "MiniMax-M3"
|
||||
api_key: "your-minimax-api-key-here"
|
||||
api_base: "https://api.minimax.io/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
|
||||
max_tokens: 4000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Planner LLM - small, fast model used for internal utility tasks
|
||||
#
|
||||
# The PLANNER role handles short, structured internal calls (KB query
|
||||
# rewriting, date extraction, recency classification, etc.) that don't
|
||||
# need frontier-tier capability. Pointing the planner at a cheap+fast
|
||||
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
|
||||
# typically saves 500ms-1.5s per turn vs. routing those same internal
|
||||
# calls through the user's chat model.
|
||||
#
|
||||
# Activation:
|
||||
# - Mark EXACTLY ONE global config with ``is_planner: true``.
|
||||
# - If multiple are marked, the first one wins and a WARNING is logged.
|
||||
# - If none is marked, every internal call falls back to the user's
|
||||
# chat LLM (same behavior as before this flag existed).
|
||||
#
|
||||
# This config is operator-only — it is NOT exposed in the user-facing
|
||||
# model selector, never billed against premium quota, and the
|
||||
# billing_tier / anonymous_enabled fields below are ignored.
|
||||
# Planner LLM. This is operator-only and is not shown in the user-facing
|
||||
# model selector. Only one global_llm_configs entry should set is_planner.
|
||||
- id: -9
|
||||
name: "Global Planner (GPT-4o mini)"
|
||||
description: "Internal-only planner LLM for query rewriting and classification"
|
||||
name: "Azure GPT 5.x Nano Planner"
|
||||
is_planner: true
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
quota_reserve_tokens: 1000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o-mini"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500
|
||||
tpm: 200000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-nano"
|
||||
supports_image_input: false
|
||||
supports_tools: false
|
||||
router_pool_eligible: false
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 20000
|
||||
tpm: 4000000
|
||||
litellm_params:
|
||||
temperature: 0
|
||||
max_tokens: 1000
|
||||
base_model: "gpt-5.4-nano"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Integration
|
||||
# OpenRouter Dynamic Model Integration
|
||||
# =============================================================================
|
||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||
# and injects them as global configs. This gives premium users access to any model
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
|
||||
# while free-tier OpenRouter models show up with a green Free badge and do NOT
|
||||
# consume premium quota.
|
||||
# Models are fetched at startup and refreshed periodically in the background.
|
||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||
# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects
|
||||
# supported models as GLOBAL chat and optionally image-generation models.
|
||||
# Tier is derived per model from OpenRouter data:
|
||||
# - model id ends with ":free" -> billing_tier=free
|
||||
# - prompt and completion pricing are zero -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
#
|
||||
# Do not use deprecated openrouter_integration.billing_tier or
|
||||
# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous
|
||||
# switches below.
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-your-openrouter-api-key"
|
||||
|
||||
# Tier is derived PER MODEL from OpenRouter's own API signals:
|
||||
# - id ends with ":free" -> billing_tier=free
|
||||
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
# No global billing_tier knob is honored; any legacy value emits a startup warning.
|
||||
|
||||
# Anonymous access is split by tier so operators can expose only free
|
||||
# models to no-login users without leaking paid inference.
|
||||
anonymous_enabled_paid: false
|
||||
anonymous_enabled_free: false
|
||||
|
||||
seo_enabled: false
|
||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||
quota_reserve_tokens: 4000
|
||||
# id_offset: base negative ID for dynamically generated configs.
|
||||
# Model IDs are derived deterministically via BLAKE2b so they survive
|
||||
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
|
||||
|
||||
# Base negative ID namespace for dynamic chat models. IDs are derived
|
||||
# deterministically so they survive catalog churn. Do not overlap static IDs.
|
||||
id_offset: -10000
|
||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||
|
||||
# Separate base negative ID namespace for dynamic image-generation models.
|
||||
image_id_offset: -20000
|
||||
|
||||
# How often to refresh the OpenRouter catalog. 0 means startup only.
|
||||
refresh_interval_hours: 24
|
||||
|
||||
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
|
||||
# for per-deployment accounting when OR premium models participate in the
|
||||
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
|
||||
# real account limits live at https://openrouter.ai/settings/limits.
|
||||
# Paid OpenRouter models may join curated router pools when eligible.
|
||||
rpm: 200
|
||||
tpm: 1000000
|
||||
|
||||
# Rate limits for FREE OpenRouter models. Informational only: free OR
|
||||
# models are intentionally kept OUT of the LiteLLM Router pool, because
|
||||
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
|
||||
# 50-1000 daily requests across every ":free" model combined) —
|
||||
# per-deployment router accounting can't represent a shared bucket
|
||||
# correctly. Free OR models stay fully available in the model selector
|
||||
# and for user-facing Auto thread pinning.
|
||||
# Free OpenRouter models are available for user-facing selection/pinning but
|
||||
# should be treated as a shared-account bucket, not normal router capacity.
|
||||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||
# contains hundreds of image- and vision-capable models; turning these on
|
||||
# injects them into the global Image-Generation / Vision-LLM model
|
||||
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||
# When a user picks a premium image/vision model the call debits the
|
||||
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||
# avoids surprise quota burn on existing deployments. Default: false.
|
||||
# Image generation is opt-in to avoid injecting a large image catalog during
|
||||
# upgrades. Vision-capable chat models are represented with
|
||||
# supports_image_input: true.
|
||||
image_generation_enabled: false
|
||||
vision_enabled: false
|
||||
|
||||
|
|
@ -367,191 +241,80 @@ openrouter_integration:
|
|||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Configuration
|
||||
# Image Generation Auto Mode Router Settings
|
||||
# =============================================================================
|
||||
# These configurations power the image generation feature using litellm.aimage_generation().
|
||||
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
|
||||
# Recraft, OpenRouter, Xinference, Nscale
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
|
||||
|
||||
# Router Settings for Image Generation Auto Mode
|
||||
image_generation_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Image Generation Models
|
||||
# =============================================================================
|
||||
global_image_generation_configs:
|
||||
# Example: OpenAI DALL-E 3
|
||||
- id: -1
|
||||
name: "Global DALL-E 3"
|
||||
description: "OpenAI's DALL-E 3 for high-quality image generation"
|
||||
provider: "OPENAI"
|
||||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
- id: -2
|
||||
name: "Global GPT Image 1"
|
||||
description: "OpenAI's GPT Image 1 model"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-image-1"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50
|
||||
litellm_params: {}
|
||||
|
||||
# Example: Azure OpenAI DALL-E 3
|
||||
- id: -3
|
||||
name: "Global Azure DALL-E 3"
|
||||
description: "Azure-hosted DALL-E 3 deployment"
|
||||
provider: "AZURE_OPENAI"
|
||||
model_name: "azure/dall-e-3-deployment"
|
||||
- id: -2001
|
||||
name: "Azure GPT Image 1.5"
|
||||
billing_tier: "premium"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1.5"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 50
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 60
|
||||
litellm_params:
|
||||
base_model: "dall-e-3"
|
||||
base_model: "gpt-image-1.5"
|
||||
|
||||
# Example: OpenRouter Gemini Image Generation
|
||||
# - id: -4
|
||||
# name: "Global Gemini Image Gen"
|
||||
# description: "Google Gemini image generation via OpenRouter"
|
||||
# provider: "OPENROUTER"
|
||||
# model_name: "google/gemini-2.5-flash-image"
|
||||
# api_key: "your-openrouter-api-key-here"
|
||||
# api_base: ""
|
||||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
- id: -2002
|
||||
name: "Azure GPT Image 1 Mini"
|
||||
billing_tier: "free"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1-mini"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 120
|
||||
litellm_params:
|
||||
base_model: "gpt-image-1-mini"
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM Configuration
|
||||
# Field Notes
|
||||
# =============================================================================
|
||||
# These configurations power the vision autocomplete feature (screenshot analysis).
|
||||
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
|
||||
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
|
||||
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
|
||||
# Common chat/image fields:
|
||||
# - provider: Canonical provider adapter name. Example: azure, openai,
|
||||
# anthropic, openrouter, groq, bedrock.
|
||||
# - model_name: Provider model or deployment id. For Azure, use the bare
|
||||
# deployment name. The resolver prefixes LiteLLM model strings from provider.
|
||||
# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the
|
||||
# resolver adds /v1 when needed.
|
||||
# - api_version: Optional provider-specific API version, stored on the
|
||||
# materialized connection extra metadata.
|
||||
# - litellm_params: Passed to LiteLLM when invoking the model. Also used for
|
||||
# base_model and inline pricing registration.
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
|
||||
|
||||
# Router Settings for Vision LLM Auto Mode
|
||||
vision_llm_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_vision_llm_configs:
|
||||
# Example: OpenAI GPT-4o (recommended for vision)
|
||||
- id: -1
|
||||
name: "Global GPT-4o Vision"
|
||||
description: "OpenAI's GPT-4o with strong vision capabilities"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Google Gemini 2.0 Flash
|
||||
- id: -2
|
||||
name: "Global Gemini 2.0 Flash"
|
||||
description: "Google's fast vision model with large context"
|
||||
provider: "GOOGLE"
|
||||
model_name: "gemini-2.0-flash"
|
||||
api_key: "your-google-ai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Anthropic Claude 3.5 Sonnet
|
||||
- id: -3
|
||||
name: "Global Claude 3.5 Sonnet Vision"
|
||||
description: "Anthropic's Claude 3.5 Sonnet with vision support"
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-5-sonnet-20241022"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# - id: -4
|
||||
# name: "Global Azure GPT-4o Vision"
|
||||
# description: "Azure-hosted GPT-4o for vision analysis"
|
||||
# provider: "AZURE_OPENAI"
|
||||
# model_name: "azure/gpt-4o-deployment"
|
||||
# api_key: "your-azure-api-key-here"
|
||||
# api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-02-15-preview"
|
||||
# rpm: 500
|
||||
# tpm: 100000
|
||||
# litellm_params:
|
||||
# temperature: 0.3
|
||||
# max_tokens: 1000
|
||||
# base_model: "gpt-4o"
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
# - system_instructions: Custom prompt or empty string to use defaults
|
||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||
# - All standard LiteLLM providers are supported
|
||||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
# Chat model fields:
|
||||
# - supports_image_input: true when the chat model can consume image inputs.
|
||||
# - supports_tools: true when the model can use tools/function calling.
|
||||
# - max_input_tokens: Optional UI/catalog metadata for context size.
|
||||
# - router_pool_eligible: false keeps a model out of shared router pools while
|
||||
# still allowing direct selection/pinning.
|
||||
# - is_planner: true marks the internal-only planner model. Only one config
|
||||
# should set this flag.
|
||||
#
|
||||
# Catalog and access fields:
|
||||
# - billing_tier: "free" or "premium".
|
||||
# - anonymous_enabled: Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once
|
||||
# public.
|
||||
# - seo_title / seo_description: Optional SEO metadata overrides.
|
||||
# - quota_reserve_tokens: Tokens reserved before each chat LLM call.
|
||||
# - rpm / tpm: Optional rate limits for router accounting and load balancing.
|
||||
#
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
#
|
||||
# VISION LLM NOTES:
|
||||
# - Vision configs use the same ID scheme (negative for global, positive for user DB)
|
||||
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
|
||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||
#
|
||||
# PLANNER LLM NOTES:
|
||||
# - is_planner: true marks a config as the internal-only planner LLM (small,
|
||||
# fast model used for KB query rewriting, date extraction, recency
|
||||
# classification, etc.). Only one config may carry this flag — if
|
||||
# multiple do, the first one wins and a startup WARNING is logged.
|
||||
# - When no config is marked is_planner, every internal utility call falls
|
||||
# back to the user's chat LLM (the historical behavior).
|
||||
# - Planner configs are NOT shown in the user-facing model selector and
|
||||
# are NOT billed against the user's premium quota. Their billing_tier,
|
||||
# anonymous_enabled, seo_* fields are ignored.
|
||||
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
|
||||
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
|
||||
# prompt. Frontier models here defeat the purpose of the flag.
|
||||
#
|
||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
|
||||
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
|
||||
# - seo_description: Optional meta description override for the model's /free/<slug> page.
|
||||
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
|
||||
# Independent of litellm_params.max_tokens. Used by the token quota service.
|
||||
# Image generation notes:
|
||||
# - Image-generation configs use the same GLOBAL ID namespace as chat models.
|
||||
# - Only RPM is relevant for most image-generation APIs.
|
||||
# - The runtime uses litellm.aimage_generation().
|
||||
# - Image billing currently uses billing_tier and model catalog metadata. Keep
|
||||
# quota reserve tuning in code/catalog unless the materializer copies a YAML
|
||||
# key for image quota reservation.
|
||||
|
|
|
|||
|
|
@ -90,11 +90,12 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, metadata, error
|
||||
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name)
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name),
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
markdown = result.markdown_content
|
||||
return markdown, metadata, None
|
||||
|
|
|
|||
|
|
@ -122,12 +122,13 @@ async def download_and_extract_content(
|
|||
async def _parse_file_to_markdown(
|
||||
file_path: str, filename: str, *, vision_llm=None
|
||||
) -> str:
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
"""Parse a local file to markdown via the cache-aware ETL pipeline."""
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=file_path, filename=filename),
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
return result.markdown_content
|
||||
|
||||
|
|
|
|||
|
|
@ -84,11 +84,12 @@ async def download_and_extract_content(
|
|||
async def _parse_file_to_markdown(
|
||||
file_path: str, filename: str, *, vision_llm=None
|
||||
) -> str:
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
"""Parse a local file to markdown via the cache-aware ETL pipeline."""
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=file_path, filename=filename),
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
return result.markdown_content
|
||||
|
|
|
|||
|
|
@ -201,79 +201,15 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
BEDROCK = "BEDROCK"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GROQ = "GROQ"
|
||||
COHERE = "COHERE"
|
||||
MISTRAL = "MISTRAL"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
REPLICATE = "REPLICATE"
|
||||
PERPLEXITY = "PERPLEXITY"
|
||||
OLLAMA = "OLLAMA"
|
||||
ALIBABA_QWEN = "ALIBABA_QWEN"
|
||||
MOONSHOT = "MOONSHOT"
|
||||
ZHIPU = "ZHIPU"
|
||||
ANYSCALE = "ANYSCALE"
|
||||
DEEPINFRA = "DEEPINFRA"
|
||||
CEREBRAS = "CEREBRAS"
|
||||
SAMBANOVA = "SAMBANOVA"
|
||||
AI21 = "AI21"
|
||||
CLOUDFLARE = "CLOUDFLARE"
|
||||
DATABRICKS = "DATABRICKS"
|
||||
COMETAPI = "COMETAPI"
|
||||
HUGGINGFACE = "HUGGINGFACE"
|
||||
GITHUB_MODELS = "GITHUB_MODELS"
|
||||
MINIMAX = "MINIMAX"
|
||||
CUSTOM = "CUSTOM"
|
||||
class ConnectionScope(StrEnum):
|
||||
GLOBAL = "GLOBAL"
|
||||
SEARCH_SPACE = "SEARCH_SPACE"
|
||||
USER = "USER"
|
||||
|
||||
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class VisionProvider(StrEnum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
OLLAMA = "OLLAMA"
|
||||
GROQ = "GROQ"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
MISTRAL = "MISTRAL"
|
||||
CUSTOM = "CUSTOM"
|
||||
class ModelSource(StrEnum):
|
||||
DISCOVERED = "DISCOVERED"
|
||||
MANUAL = "MANUAL"
|
||||
|
||||
|
||||
class LogLevel(StrEnum):
|
||||
|
|
@ -702,11 +638,11 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||
# Auto model pin for this thread: concrete resolved global LLM
|
||||
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||
# or clears this column (plus bulk clears when a search space's
|
||||
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||
# chat_model_id changes). Unindexed: all reads are by primary key.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Surface metadata for first-party SurfSense and external chat threads.
|
||||
|
|
@ -1487,7 +1423,10 @@ class Document(BaseModel, TimestampMixin):
|
|||
created_by = relationship("User", back_populates="documents")
|
||||
connector = relationship("SearchSourceConnector", back_populates="documents")
|
||||
chunks = relationship(
|
||||
"Chunk", back_populates="document", cascade="all, delete-orphan"
|
||||
"Chunk",
|
||||
back_populates="document",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="Chunk.position",
|
||||
)
|
||||
# Original upload + future derived artifacts (redacted, filled-form).
|
||||
# Model lives in app.file_storage.persistence to keep that feature cohesive.
|
||||
|
|
@ -1523,6 +1462,9 @@ class Chunk(BaseModel, TimestampMixin):
|
|||
|
||||
content = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
# Explicit document order; ids don't follow it since incremental
|
||||
# re-indexing keeps unchanged rows across edits.
|
||||
position = Column(Integer, nullable=False, server_default="0", index=True)
|
||||
|
||||
document_id = Column(
|
||||
Integer,
|
||||
|
|
@ -1604,73 +1546,80 @@ class Report(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
class Connection(BaseModel, TimestampMixin):
|
||||
__tablename__ = "connections"
|
||||
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="image_generation_configs")
|
||||
|
||||
|
||||
class VisionLLMConfig(BaseModel, TimestampMixin):
|
||||
__tablename__ = "vision_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True)
|
||||
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
base_url = Column(String(500), nullable=True)
|
||||
api_key = Column(String, nullable=True)
|
||||
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="connections")
|
||||
user = relationship("User", back_populates="connections")
|
||||
models = relationship(
|
||||
"Model",
|
||||
back_populates="connection",
|
||||
order_by="Model.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
|
||||
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
|
||||
"(scope = 'USER' AND user_id IS NOT NULL)",
|
||||
name="ck_connections_scope_owner",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Model(BaseModel, TimestampMixin):
|
||||
__tablename__ = "models"
|
||||
|
||||
connection_id = Column(
|
||||
Integer,
|
||||
ForeignKey("connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model_id = Column(String(255), nullable=False)
|
||||
display_name = Column(String(255), nullable=True)
|
||||
source = Column(
|
||||
SQLAlchemyEnum(ModelSource),
|
||||
nullable=False,
|
||||
default=ModelSource.DISCOVERED,
|
||||
server_default=ModelSource.DISCOVERED.value,
|
||||
)
|
||||
supports_chat = Column(Boolean, nullable=True)
|
||||
max_input_tokens = Column(Integer, nullable=True)
|
||||
supports_image_input = Column(Boolean, nullable=True)
|
||||
supports_tools = Column(Boolean, nullable=True)
|
||||
supports_image_generation = Column(Boolean, nullable=True)
|
||||
capabilities_override = Column(
|
||||
JSONB, nullable=False, default=dict, server_default="{}"
|
||||
)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
billing_tier = Column(String(50), nullable=True, index=True)
|
||||
catalog = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
|
||||
connection = relationship("Connection", back_populates="models")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
Index("ix_models_model_id", "model_id"),
|
||||
)
|
||||
user = relationship("User", back_populates="vision_llm_configs")
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
|
|
@ -1704,10 +1653,9 @@ class ImageGeneration(BaseModel, TimestampMixin):
|
|||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
# Image generation model provenance.
|
||||
# 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records.
|
||||
image_gen_model_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
|
|
@ -1749,19 +1697,19 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
# - Negative IDs: Global configs from YAML
|
||||
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||
agent_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
# Connection/model role bindings.
|
||||
# Note: ID values preserve the existing convention:
|
||||
# - 0: Auto mode
|
||||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||
# - Positive IDs: User/search-space models from the models table
|
||||
chat_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For agent/chat operations, defaults to Auto mode
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
image_gen_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For image generation, defaults to Auto mode when eligible
|
||||
vision_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
ai_file_sort_enabled = Column(
|
||||
|
|
@ -1833,23 +1781,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="SearchSourceConnector.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="search_space",
|
||||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="VisionLLMConfig.id",
|
||||
order_by="Connection.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
automations = relationship(
|
||||
|
|
@ -1952,64 +1889,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
documents = relationship("Document", back_populates="connector")
|
||||
|
||||
|
||||
class NewLLMConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
New LLM configuration table that combines model settings with prompt configuration.
|
||||
|
||||
This table provides:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
|
||||
# Provider from the enum
|
||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||
# Custom provider name when provider is CUSTOM
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
# Just the model name without provider prefix
|
||||
model_name = Column(String(100), nullable=False)
|
||||
# API Key should be encrypted before storing
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# === Prompt Configuration ===
|
||||
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
# Users can customize this from the UI
|
||||
system_instructions = Column(
|
||||
Text,
|
||||
nullable=False,
|
||||
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
)
|
||||
# Whether to use the default system instructions when system_instructions is empty
|
||||
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
|
||||
# When disabled, an anti-citation prompt is injected instead
|
||||
citations_enabled = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# === Relationships ===
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="new_llm_configs")
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
__tablename__ = "logs"
|
||||
|
||||
|
|
@ -2376,22 +2255,8 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
@ -2522,22 +2387,8 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
@ -2867,7 +2718,11 @@ from app.automations.persistence import ( # noqa: E402, F401
|
|||
AutomationRun,
|
||||
AutomationTrigger,
|
||||
)
|
||||
from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401
|
||||
from app.file_storage.persistence import DocumentFile # noqa: E402, F401
|
||||
from app.indexing_pipeline.cache.persistence.models import ( # noqa: E402, F401
|
||||
CachedEmbeddingSet,
|
||||
)
|
||||
from app.notifications.persistence import Notification # noqa: E402, F401
|
||||
from app.podcasts.persistence import ( # noqa: E402, F401
|
||||
Podcast,
|
||||
|
|
|
|||
11
surfsense_backend/app/etl_pipeline/cache/__init__.py
vendored
Normal file
11
surfsense_backend/app/etl_pipeline/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Content-addressed reuse of expensive ETL parser output across workspaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.cached_extraction import extract_with_cache
|
||||
from app.etl_pipeline.cache.service import EtlCacheService
|
||||
|
||||
__all__ = [
|
||||
"EtlCacheService",
|
||||
"extract_with_cache",
|
||||
]
|
||||
86
surfsense_backend/app/etl_pipeline/cache/cached_extraction.py
vendored
Normal file
86
surfsense_backend/app/etl_pipeline/cache/cached_extraction.py
vendored
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Entry point: serve ETL parses from cache, parsing only on a miss."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from app.config import config
|
||||
from app.etl_pipeline.cache.eligibility import is_parse_cacheable
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.service import EtlCacheService
|
||||
from app.etl_pipeline.cache.settings import load_etl_cache_settings
|
||||
from app.etl_pipeline.etl_document import EtlRequest, EtlResult
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.observability import metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HASH_CHUNK = 1024 * 1024
|
||||
|
||||
|
||||
async def extract_with_cache(request: EtlRequest, *, vision_llm=None) -> EtlResult:
    """Extract ``request`` through the parse cache, parsing only on a miss.

    Meant as a drop-in for ``EtlPipelineService.extract``. Requests rejected by
    ``is_parse_cacheable`` are handed straight to the pipeline service and
    never touch the cache.

    Args:
        request: The ETL request; ``filename``, ``file_path`` and
            ``processing_mode`` are read here.
        vision_llm: Optional vision model forwarded to ``EtlPipelineService``.

    Returns:
        The recalled ``EtlResult`` on a hit, otherwise the freshly parsed one.
    """
    cache_settings = load_etl_cache_settings()

    if not is_parse_cacheable(
        filename=request.filename,
        etl_service=config.ETL_SERVICE,
        cache_enabled=cache_settings.enabled,
        has_vision_llm=vision_llm is not None,
    ):
        return await EtlPipelineService(vision_llm=vision_llm).extract(request)

    # Hash in a worker thread so reading the file does not block the event loop.
    digest = await asyncio.to_thread(_hash_file, request.file_path)
    parse_key = ParseKey.for_document(
        digest,
        etl_service=config.ETL_SERVICE,
        mode=request.processing_mode.value,
        version=cache_settings.parser_version,
    )

    hit = await _recall(parse_key)
    metrics.record_etl_cache_lookup(
        etl_service=parse_key.etl_service,
        mode=parse_key.mode,
        outcome="miss" if hit is None else "hit",
    )
    if hit is not None:
        logger.debug("ETL cache hit for %s", parse_key.source_sha256)
        return hit

    parsed = await EtlPipelineService(vision_llm=vision_llm).extract(request)
    await _remember(parse_key, parsed)
    return parsed
|
||||
|
||||
|
||||
async def _recall(key: ParseKey) -> EtlResult | None:
    """Look ``key`` up in the parse cache, returning ``None`` on a miss or any error.

    The cache is best-effort: every exception is logged and reported as a miss
    so the caller falls back to a normal parse.
    """
    try:
        from app.tasks.celery_tasks import get_celery_session_maker

        session_factory = get_celery_session_maker()
        async with session_factory() as session:
            cache = EtlCacheService(session)
            return await cache.recall(key)
    except Exception:
        logger.warning("ETL cache recall failed; parsing fresh", exc_info=True)
        return None
|
||||
|
||||
|
||||
async def _remember(key: ParseKey, result: EtlResult) -> None:
    """Store ``result`` under ``key``; failures are logged and otherwise ignored."""
    try:
        from app.tasks.celery_tasks import get_celery_session_maker

        session_factory = get_celery_session_maker()
        async with session_factory() as session:
            cache = EtlCacheService(session)
            await cache.remember(key, result)
    except Exception:
        # Best-effort write: the caller already has its result.
        logger.warning("ETL cache write failed; result not cached", exc_info=True)
|
||||
|
||||
|
||||
def _hash_file(path: str) -> str:
    """Return the hex SHA-256 of the file at ``path``.

    The file is read in ``_HASH_CHUNK``-byte pieces so it is never held in
    memory whole.
    """
    hasher = hashlib.sha256()
    with open(path, "rb") as stream:
        while block := stream.read(_HASH_CHUNK):
            hasher.update(block)
    return hasher.hexdigest()
|
||||
28
surfsense_backend/app/etl_pipeline/cache/eligibility.py
vendored
Normal file
28
surfsense_backend/app/etl_pipeline/cache/eligibility.py
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Gating rule: may this upload be served from / written to the parse cache?"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.file_classifier import FileCategory, classify_file
|
||||
|
||||
|
||||
def is_parse_cacheable(
|
||||
*,
|
||||
filename: str,
|
||||
etl_service: str | None,
|
||||
cache_enabled: bool,
|
||||
has_vision_llm: bool,
|
||||
) -> bool:
|
||||
"""Only deterministic document parses are shareable across workspaces.
|
||||
|
||||
Vision-LLM runs append model-generated content not captured by the cache key,
|
||||
and a missing ETL service means there is no document parser to key against --
|
||||
both bypass the cache. Non-document categories (plaintext, audio, images,
|
||||
direct-convert) are cheap or parser-agnostic and are handled outside it.
|
||||
"""
|
||||
if not cache_enabled:
|
||||
return False
|
||||
if has_vision_llm:
|
||||
return False
|
||||
if not etl_service:
|
||||
return False
|
||||
return classify_file(filename) == FileCategory.DOCUMENT
|
||||
9
surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py
vendored
Normal file
9
surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Background pruning of the parse cache by age and size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .task import evict_etl_cache_task
|
||||
|
||||
__all__ = [
|
||||
"evict_etl_cache_task",
|
||||
]
|
||||
28
surfsense_backend/app/etl_pipeline/cache/eviction/policy.py
vendored
Normal file
28
surfsense_backend/app/etl_pipeline/cache/eviction/policy.py
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Pure selection rules for which cached entries to drop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
|
||||
|
||||
def select_over_budget(
    coldest_first: Iterable[EvictionCandidate],
    *,
    current_total_bytes: int,
    max_total_bytes: int,
) -> list[EvictionCandidate]:
    """Choose entries to drop so the cache footprint reaches the byte budget.

    Candidates are taken in the order given (coldest first) until their
    combined ``size_bytes`` covers the overshoot. If the input runs out first,
    everything it yielded is returned.

    Args:
        coldest_first: Candidates ordered from most to least evictable.
        current_total_bytes: Present size of the cache.
        max_total_bytes: Size the cache is allowed to occupy.

    Returns:
        The selected candidates in input order; empty when the cache is
        already at or under budget.
    """
    picked: list[EvictionCandidate] = []
    still_needed = current_total_bytes - max_total_bytes
    if still_needed <= 0:
        return picked

    for entry in coldest_first:
        if still_needed <= 0:
            break
        picked.append(entry)
        still_needed -= entry.size_bytes
    return picked
|
||||
68
surfsense_backend/app/etl_pipeline/cache/eviction/task.py
vendored
Normal file
68
surfsense_backend/app/etl_pipeline/cache/eviction/task.py
vendored
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Celery task that prunes the parse cache by TTL, then by size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.etl_pipeline.cache.eviction.policy import select_over_budget
|
||||
from app.etl_pipeline.cache.persistence import CachedParseRepository
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.etl_pipeline.cache.settings import load_etl_cache_settings
|
||||
from app.etl_pipeline.cache.storage import MarkdownCacheStore
|
||||
from app.observability import metrics
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="evict_etl_cache")
def evict_etl_cache_task():
    """Celery task ``evict_etl_cache``: run the async ``_evict`` routine."""
    outcome = run_async_celery_task(_evict)
    return outcome
|
||||
|
||||
|
||||
async def _evict() -> None:
    """Prune the parse cache: first by age, then by size if still over budget.

    Does nothing when the cache is disabled. Each phase handles at most
    ``eviction_batch`` entries per run.
    """
    cfg = load_etl_cache_settings()
    if not cfg.enabled:
        return

    blobs = MarkdownCacheStore()
    session_factory = get_celery_session_maker()
    async with session_factory() as session:
        repo = CachedParseRepository(session)

        # Phase 1: entries not used since the TTL cutoff.
        stale_before = datetime.now(UTC) - timedelta(days=cfg.ttl_days)
        stale = await repo.select_expired(
            cutoff=stale_before, limit=cfg.eviction_batch
        )
        await _drop(repo, blobs, stale, phase="ttl")

        # Phase 2: measured after the TTL pass, so only the remaining overflow is shed.
        footprint = await repo.total_size_bytes()
        if footprint <= cfg.max_total_bytes:
            return

        coldest = await repo.select_coldest(limit=cfg.eviction_batch)
        overflow = select_over_budget(
            coldest,
            current_total_bytes=footprint,
            max_total_bytes=cfg.max_total_bytes,
        )
        await _drop(repo, blobs, overflow, phase="size")
|
||||
|
||||
|
||||
async def _drop(
    index: CachedParseRepository,
    store: MarkdownCacheStore,
    candidates: list[EvictionCandidate],
    *,
    phase: str,
) -> None:
    """Delete the blobs and index rows for ``candidates`` and record the eviction.

    Args:
        index: Repository used to delete the index rows.
        store: Blob store holding the cached markdown.
        candidates: Entries to remove; an empty list is a no-op.
        phase: Label passed to the eviction metric and the log line.
    """
    if not candidates:
        return

    doomed_ids = []
    for entry in candidates:
        # Blob removal is best-effort: the index row goes away even if the
        # blob delete fails (an orphan blob is harmless).
        with contextlib.suppress(Exception):
            await store.delete(entry.storage_key)
        doomed_ids.append(entry.id)

    evicted = len(candidates)
    await index.delete_by_ids(doomed_ids)
    metrics.record_etl_cache_eviction(evicted, phase=phase)
    logger.info("Evicted %d cached parses (%s)", evicted, phase)
|
||||
11
surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py
vendored
Normal file
11
surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Database access for cached parse rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import CachedParse
|
||||
from .repository import CachedParseRepository
|
||||
|
||||
__all__ = [
|
||||
"CachedParse",
|
||||
"CachedParseRepository",
|
||||
]
|
||||
49
surfsense_backend/app/etl_pipeline/cache/persistence/models.py
vendored
Normal file
49
surfsense_backend/app/etl_pipeline/cache/persistence/models.py
vendored
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""``etl_cache_parses``: one reusable parser result per (bytes + recipe)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
|
||||
class CachedParse(BaseModel, TimestampMixin):
    """Index row for one cached parser result in ``etl_cache_parses``.

    A row is unique per (source bytes hash, ETL service, mode, parser version).
    The markdown itself is stored in blob storage and referenced by
    ``storage_key``.
    """

    __tablename__ = "etl_cache_parses"

    # --- Cache key: hash of the raw bytes plus the recipe that produced the markdown.
    source_sha256 = Column(String(64), nullable=False)
    etl_service = Column(String(32), nullable=False)
    mode = Column(String(16), nullable=False)
    parser_version = Column(Integer, nullable=False)

    # --- Blob location; the markdown is kept out of the row so rows stay small.
    storage_backend = Column(String(32), nullable=False)
    storage_key = Column(String, nullable=False)
    size_bytes = Column(BigInteger, nullable=False)

    # --- Fields needed to rebuild the EtlResult on a hit.
    content_type = Column(String(32), nullable=False)
    actual_pages = Column(Integer, nullable=False, default=0, server_default="0")

    # --- Eviction inputs: popularity and recency.
    times_reused = Column(BigInteger, nullable=False, default=0, server_default="0")
    last_used_at = Column(DateTime(timezone=True), nullable=False)

    __table_args__ = (
        UniqueConstraint(
            "source_sha256",
            "etl_service",
            "mode",
            "parser_version",
            name="uq_etl_cache_parses_key",
        ),
        Index("ix_etl_cache_parses_last_used_at", "last_used_at"),
    )
|
||||
121
surfsense_backend/app/etl_pipeline/cache/persistence/repository.py
vendored
Normal file
121
surfsense_backend/app/etl_pipeline/cache/persistence/repository.py
vendored
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""CRUD and eviction selectors for ``etl_cache_parses`` (no business rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate, ParseKey
|
||||
|
||||
from .models import CachedParse
|
||||
|
||||
# Column projection used by the eviction selectors; one entry per field of
# ``EvictionCandidate``.
_EVICTION_COLUMNS = (
    CachedParse.id,
    CachedParse.storage_key,
    CachedParse.size_bytes,
    CachedParse.last_used_at,
    CachedParse.times_reused,
)
|
||||
|
||||
|
||||
def _as_eviction_candidate(row) -> EvictionCandidate:
    """Copy the eviction-relevant attributes of ``row`` into an ``EvictionCandidate``."""
    field_names = ("id", "storage_key", "size_bytes", "last_used_at", "times_reused")
    return EvictionCandidate(**{name: getattr(row, name) for name in field_names})
|
||||
|
||||
|
||||
class CachedParseRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get(self, key: ParseKey) -> CachedParse | None:
|
||||
result = await self._session.execute(
|
||||
select(CachedParse).where(
|
||||
CachedParse.source_sha256 == key.source_sha256,
|
||||
CachedParse.etl_service == key.etl_service,
|
||||
CachedParse.mode == key.mode,
|
||||
CachedParse.parser_version == key.version,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
*,
|
||||
key: ParseKey,
|
||||
content_type: str,
|
||||
actual_pages: int,
|
||||
storage_backend: str,
|
||||
storage_key: str,
|
||||
size_bytes: int,
|
||||
) -> None:
|
||||
# Concurrent writers parse identical bytes, so a lost race is harmless.
|
||||
now = datetime.now(UTC)
|
||||
await self._session.execute(
|
||||
pg_insert(CachedParse)
|
||||
.values(
|
||||
source_sha256=key.source_sha256,
|
||||
etl_service=key.etl_service,
|
||||
mode=key.mode,
|
||||
parser_version=key.version,
|
||||
content_type=content_type,
|
||||
actual_pages=actual_pages,
|
||||
storage_backend=storage_backend,
|
||||
storage_key=storage_key,
|
||||
size_bytes=size_bytes,
|
||||
times_reused=0,
|
||||
last_used_at=now,
|
||||
created_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(constraint="uq_etl_cache_parses_key")
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def mark_used(self, row_id: int) -> None:
|
||||
await self._session.execute(
|
||||
update(CachedParse)
|
||||
.where(CachedParse.id == row_id)
|
||||
.values(
|
||||
times_reused=CachedParse.times_reused + 1,
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def total_size_bytes(self) -> int:
|
||||
result = await self._session.execute(
|
||||
select(func.coalesce(func.sum(CachedParse.size_bytes), 0))
|
||||
)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def select_expired(
|
||||
self, *, cutoff: datetime, limit: int
|
||||
) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.where(CachedParse.last_used_at < cutoff)
|
||||
.order_by(CachedParse.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.order_by(CachedParse.times_reused.asc(), CachedParse.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def delete_by_ids(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
await self._session.execute(delete(CachedParse).where(CachedParse.id.in_(ids)))
|
||||
await self._session.commit()
|
||||
11
surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py
vendored
Normal file
11
surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Pure value objects for the parse cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .eviction_candidate import EvictionCandidate
|
||||
from .parse_key import ParseKey
|
||||
|
||||
__all__ = [
|
||||
"EvictionCandidate",
|
||||
"ParseKey",
|
||||
]
|
||||
15
surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py
vendored
Normal file
15
surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py
vendored
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Row projection handed to the eviction policy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EvictionCandidate:
|
||||
id: int
|
||||
storage_key: str
|
||||
size_bytes: int
|
||||
last_used_at: datetime
|
||||
times_reused: int
|
||||
28
surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py
vendored
Normal file
28
surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Identity of a cacheable parse: equal keys yield identical markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ParseKey:
|
||||
source_sha256: str
|
||||
etl_service: str
|
||||
mode: str
|
||||
version: int
|
||||
|
||||
@classmethod
|
||||
def for_document(
|
||||
cls, source_sha256: str, *, etl_service: str, mode: str, version: int
|
||||
) -> ParseKey:
|
||||
return cls(
|
||||
source_sha256=source_sha256,
|
||||
etl_service=etl_service,
|
||||
mode=mode,
|
||||
version=version,
|
||||
)
|
||||
|
||||
@property
|
||||
def object_suffix(self) -> str:
|
||||
return f"{self.etl_service}.{self.mode}.v{self.version}.md"
|
||||
53
surfsense_backend/app/etl_pipeline/cache/service.py
vendored
Normal file
53
surfsense_backend/app/etl_pipeline/cache/service.py
vendored
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Recall and remember parser output, coordinating the index and blob store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.persistence import CachedParseRepository
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.storage import MarkdownCacheStore
|
||||
from app.etl_pipeline.etl_document import EtlResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EtlCacheService:
    """Recall and store parser output, coordinating the index and the blob store."""

    def __init__(self, session: AsyncSession) -> None:
        self._index = CachedParseRepository(session)
        self._store = MarkdownCacheStore()

    async def recall(self, key: ParseKey) -> EtlResult | None:
        """Return the cached result for ``key``, or ``None`` on a miss.

        A miss is reported when no index row exists or when the row's blob
        cannot be loaded. Once the blob is loaded, a failure to record the
        reuse is logged but does not turn the hit into a miss.
        """
        row = await self._index.get(key)
        if row is None:
            return None

        try:
            markdown = await self._store.load(row.storage_key)
        except Exception:
            # Index points at a blob that is gone; treat as a miss and re-parse.
            logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True)
            return None

        # Snapshot the row before ``mark_used`` commits: a commit can expire
        # ORM attributes, and reloading them lazily is not possible under
        # AsyncSession. NOTE(review): whether expiry happens depends on the
        # session's ``expire_on_commit`` setting, which is not visible here.
        row_id = row.id
        etl_service = row.etl_service
        actual_pages = row.actual_pages
        content_type = row.content_type

        # Usage stats only feed eviction ordering; losing one update is
        # cheaper than discarding a loaded blob and re-parsing.
        try:
            await self._index.mark_used(row_id)
        except Exception:
            logger.warning(
                "Could not record reuse of cached parse %s", row_id, exc_info=True
            )

        return EtlResult(
            markdown_content=markdown,
            etl_service=etl_service,
            actual_pages=actual_pages,
            content_type=content_type,
        )

    async def remember(self, key: ParseKey, result: EtlResult) -> None:
        """Store a freshly parsed result for future reuse.

        The blob is written first and the index row second, so a row never
        points at a blob that was not written.
        """
        storage_key = await self._store.save(key, result.markdown_content)
        await self._index.insert(
            key=key,
            content_type=result.content_type,
            actual_pages=result.actual_pages,
            storage_backend=self._store.backend_name,
            storage_key=storage_key,
            size_bytes=len(result.markdown_content.encode("utf-8")),
        )
|
||||
33
surfsense_backend/app/etl_pipeline/cache/settings.py
vendored
Normal file
33
surfsense_backend/app/etl_pipeline/cache/settings.py
vendored
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Cache configuration resolved from the central ``Config``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EtlCacheSettings:
|
||||
enabled: bool
|
||||
parser_version: int
|
||||
ttl_days: int
|
||||
max_total_bytes: int
|
||||
eviction_batch: int
|
||||
# None for any storage_* field means: reuse the main file_storage backend.
|
||||
storage_backend: str | None
|
||||
storage_container: str | None
|
||||
storage_local_root: str | None
|
||||
|
||||
|
||||
def load_etl_cache_settings() -> EtlCacheSettings:
    """Build ``EtlCacheSettings`` from the central ``config`` object.

    ``ETL_CACHE_MAX_TOTAL_MB`` is converted to bytes, and falsy storage
    overrides are normalised to ``None``.
    """
    from app.config import config

    bytes_per_mb = 1024 * 1024
    return EtlCacheSettings(
        enabled=config.ETL_CACHE_ENABLED,
        parser_version=config.ETL_CACHE_PARSER_VERSION,
        ttl_days=config.ETL_CACHE_TTL_DAYS,
        max_total_bytes=config.ETL_CACHE_MAX_TOTAL_MB * bytes_per_mb,
        eviction_batch=config.ETL_CACHE_EVICTION_BATCH,
        storage_backend=config.ETL_CACHE_STORAGE_BACKEND or None,
        storage_container=config.ETL_CACHE_STORAGE_CONTAINER or None,
        storage_local_root=config.ETL_CACHE_STORAGE_LOCAL_PATH or None,
    )
|
||||
9
surfsense_backend/app/etl_pipeline/cache/storage/__init__.py
vendored
Normal file
9
surfsense_backend/app/etl_pipeline/cache/storage/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Blob storage for cached parse markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .markdown_store import MarkdownCacheStore
|
||||
|
||||
__all__ = [
|
||||
"MarkdownCacheStore",
|
||||
]
|
||||
48
surfsense_backend/app/etl_pipeline/cache/storage/backend.py
vendored
Normal file
48
surfsense_backend/app/etl_pipeline/cache/storage/backend.py
vendored
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Resolve the storage backend for cache blobs: shared main store or a dedicated one."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from app.file_storage.backends.base import StorageBackend
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def resolve_cache_backend() -> StorageBackend:
    """Return the storage backend used for cache blobs (memoised after first success).

    With no dedicated backend configured, the main file-storage backend is
    reused. Otherwise ``"azure"`` or ``"local"`` (case-insensitive, surrounding
    whitespace ignored) selects a dedicated backend.

    Raises:
        ValueError: If the backend name is unknown, or a setting required by
            the chosen backend is missing.
    """
    from app.etl_pipeline.cache.settings import load_etl_cache_settings

    cache_settings = load_etl_cache_settings()
    configured = cache_settings.storage_backend

    if not configured:
        from app.file_storage.factory import get_storage_backend

        return get_storage_backend()

    kind = configured.strip().lower()
    if kind not in ("azure", "local"):
        raise ValueError(f"Unknown ETL_CACHE_STORAGE_BACKEND: {configured!r}")

    if kind == "local":
        local_root = cache_settings.storage_local_root
        if not local_root:
            raise ValueError(
                "ETL_CACHE_STORAGE_LOCAL_PATH is required for local cache."
            )
        from app.file_storage.backends.local import LocalFileBackend

        return LocalFileBackend(local_root)

    # Remaining case: dedicated Azure container.
    from app.config import config

    container = cache_settings.storage_container
    if not container:
        raise ValueError("ETL_CACHE_STORAGE_CONTAINER is required for azure cache.")
    connection_string = config.AZURE_STORAGE_CONNECTION_STRING
    if not connection_string:
        raise ValueError(
            "AZURE_STORAGE_CONNECTION_STRING is required for azure cache."
        )
    from app.file_storage.backends.azure import AzureBlobBackend

    return AzureBlobBackend(
        connection_string=connection_string,
        container=container,
    )
|
||||
35
surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py
vendored
Normal file
35
surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py
vendored
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Read and write cached markdown blobs through the resolved backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.storage.backend import resolve_cache_backend
|
||||
from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key
|
||||
|
||||
_MARKDOWN_CONTENT_TYPE = "text/markdown; charset=utf-8"
|
||||
|
||||
|
||||
class MarkdownCacheStore:
    """Reads, writes and deletes cached markdown blobs via the resolved backend."""

    def __init__(self) -> None:
        self._backend = resolve_cache_backend()

    @property
    def backend_name(self) -> str:
        """Name reported by the underlying storage backend."""
        return self._backend.backend_name

    async def save(self, key: ParseKey, markdown: str) -> str:
        """Write ``markdown`` as UTF-8 and return the storage key for the index row."""
        object_key = build_parse_object_key(key)
        payload = markdown.encode("utf-8")
        await self._backend.put(
            object_key, payload, content_type=_MARKDOWN_CONTENT_TYPE
        )
        return object_key

    async def load(self, storage_key: str) -> str:
        """Read the whole blob at ``storage_key`` and decode it as UTF-8."""
        buffer = bytearray()
        async for piece in self._backend.open_stream(storage_key):
            buffer += piece
        # Decode only once everything has arrived, so multi-byte characters
        # split across chunks are handled.
        return bytes(buffer).decode("utf-8")

    async def delete(self, storage_key: str) -> None:
        """Remove the blob at ``storage_key`` from the backend."""
        await self._backend.delete(storage_key)
|
||||
12
surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py
vendored
Normal file
12
surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py
vendored
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Object keys for cached markdown, namespaced under a dedicated prefix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
|
||||
CACHE_PREFIX = "etl_cache"
|
||||
|
||||
|
||||
def build_parse_object_key(key: ParseKey) -> str:
    """Return the blob key ``<CACHE_PREFIX>/<source_sha256>/<object_suffix>`` for ``key``.

    The key is content-addressed: identical bytes and recipe always map to
    the same object.
    """
    return "{}/{}/{}".format(CACHE_PREFIX, key.source_sha256, key.object_suffix)
|
||||
|
|
@ -8,7 +8,7 @@ from app.config import config
|
|||
|
||||
|
||||
def require_gateway_enabled() -> None:
|
||||
"""FastAPI dependency that gates all gateway HTTP routes on the global flag.
|
||||
"""FastAPI dependency that gates gateway operational routes on the global flag.
|
||||
|
||||
Returns 404 (rather than 503) when ``GATEWAY_ENABLED`` is FALSE so that
|
||||
disabling the gateway makes its webhook/OAuth/pairing surface indistinguishable
|
||||
|
|
|
|||
11
surfsense_backend/app/indexing_pipeline/cache/__init__.py
vendored
Normal file
11
surfsense_backend/app/indexing_pipeline/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Content-addressed reuse of chunk+embedding output across workspaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings
|
||||
from app.indexing_pipeline.cache.service import EmbeddingCacheService
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingCacheService",
|
||||
"build_chunk_embeddings",
|
||||
]
|
||||
129
surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py
vendored
Normal file
129
surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py
vendored
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Entry point: serve chunk embeddings from cache, embedding only on a miss.
|
||||
|
||||
Embeddings are a pure function of the markdown, the embedding model, and the
|
||||
chunker -- so identical markdown is chunked and embedded once and reused across
|
||||
workspaces, even when it came from different sources.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import config
|
||||
from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable
|
||||
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.service import EmbeddingCacheService
|
||||
from app.indexing_pipeline.cache.settings import load_embedding_cache_settings
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.observability import metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ChunkPair = tuple[str, np.ndarray]
|
||||
|
||||
|
||||
async def build_chunk_embeddings(
    markdown: str, *, use_code_chunker: bool
) -> tuple[np.ndarray, list[ChunkPair]]:
    """Return the document-level vector and the ordered ``(chunk_text, vector)`` pairs.

    Drop-in for the inline chunk+embed step. When ``is_embedding_cacheable``
    allows it, output previously stored for the same markdown, embedding model,
    dimension, chunker kind and chunker version is reused. Otherwise the
    markdown is chunked and embedded afresh, and stored when caching applies.

    Args:
        markdown: Text to chunk and embed.
        use_code_chunker: Selects the ``"code"`` chunker instead of ``"hybrid"``.
    """
    cache_settings = load_embedding_cache_settings()
    chunker_kind = "code" if use_code_chunker else "hybrid"
    embedding_dim = getattr(config.embedding_model_instance, "dimension", None)

    if not is_embedding_cacheable(
        cache_enabled=cache_settings.enabled,
        embedding_model=config.EMBEDDING_MODEL,
        embedding_dim=embedding_dim,
    ):
        return await _compute(markdown, use_code_chunker=use_code_chunker)

    cache_key = EmbeddingKey(
        markdown_sha256=_hash_text(markdown),
        embedding_model=config.EMBEDDING_MODEL,
        embedding_dim=int(embedding_dim),
        chunker_kind=chunker_kind,
        chunker_version=cache_settings.chunker_version,
    )

    hit = await _recall(cache_key)
    metrics.record_embedding_cache_lookup(
        embedding_model=cache_key.embedding_model,
        chunker_kind=chunker_kind,
        outcome="miss" if hit is None else "hit",
    )
    if hit is not None:
        logger.debug("Embedding cache hit for %s", cache_key.markdown_sha256)
        pairs = [(chunk.text, chunk.embedding) for chunk in hit.chunks]
        return hit.summary_embedding, pairs

    doc_vector, fresh_pairs = await _compute(
        markdown, use_code_chunker=use_code_chunker
    )
    await _remember(cache_key, doc_vector, fresh_pairs)
    return doc_vector, fresh_pairs
|
||||
|
||||
|
||||
async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]:
    """Split ``markdown`` into ordered chunk texts in a worker thread.

    ``use_code_chunker`` selects ``chunk_text`` in code mode; otherwise the
    hybrid chunker is used.
    """
    if not use_code_chunker:
        # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334).
        return await asyncio.to_thread(chunk_text_hybrid, markdown)
    return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True)
|
||||
|
||||
|
||||
async def embed_batch(texts: list[str]) -> list[np.ndarray]:
    """Embed ``texts`` with a single ``embed_texts`` call run in a worker thread."""
    vectors = await asyncio.to_thread(embed_texts, texts)
    return vectors
|
||||
|
||||
|
||||
async def _compute(
    markdown: str, *, use_code_chunker: bool
) -> tuple[np.ndarray, list[ChunkPair]]:
    """Chunk the markdown, then embed the document and every chunk in one batch.

    The full markdown is placed first in the batch, so the first returned vector
    is the document-level embedding; the remaining vectors are paired with the
    chunk texts in order.

    Returns:
        ``(summary_embedding, [(chunk_text, chunk_embedding), ...])``.

    Raises:
        ValueError: if the embedder returns no vectors, or a different number of
            chunk vectors than chunk texts. Pairing is strict so a short batch
            cannot silently drop trailing chunks.
    """
    chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker)
    embeddings = await embed_batch([markdown, *chunk_texts])
    summary_embedding, *chunk_embeddings = embeddings
    # strict=True: a length mismatch is a broken embedder response, not something
    # to truncate around.
    return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=True))
|
||||
|
||||
|
||||
async def _recall(key: EmbeddingKey) -> EmbeddingSet | None:
    """Look ``key`` up in the embedding cache; ``None`` on a miss or on any error.

    Caching is best-effort: every failure is logged and reported as a miss, so
    the caller falls through to a normal embed.
    """
    try:
        from app.tasks.celery_tasks import get_celery_session_maker

        session_maker = get_celery_session_maker()
        async with session_maker() as session:
            service = EmbeddingCacheService(session)
            return await service.recall(key)
    except Exception:
        logger.warning("Embedding cache recall failed; embedding fresh", exc_info=True)
    return None
|
||||
|
||||
|
||||
async def _remember(
    key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair]
) -> None:
    """Write freshly computed embeddings to the cache under ``key``.

    Best-effort: any failure is logged as a warning and swallowed, so the
    caller's result is returned whether or not it was cached.
    """
    try:
        from app.tasks.celery_tasks import get_celery_session_maker

        cached_chunks = [
            CachedChunk(text=chunk_text, embedding=chunk_vector)
            for chunk_text, chunk_vector in chunk_pairs
        ]
        payload = EmbeddingSet(summary_embedding=summary_embedding, chunks=cached_chunks)

        session_maker = get_celery_session_maker()
        async with session_maker() as session:
            service = EmbeddingCacheService(session)
            await service.remember(key, payload)
    except Exception:
        logger.warning("Embedding cache write failed; result not cached", exc_info=True)
|
||||
|
||||
|
||||
def _hash_text(text: str) -> str:
    """Return the hex SHA-256 digest of ``text`` encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()
|
||||
21
surfsense_backend/app/indexing_pipeline/cache/eligibility.py
vendored
Normal file
21
surfsense_backend/app/indexing_pipeline/cache/eligibility.py
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Gating rule: may this document be served from / written to the embedding cache?"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def is_embedding_cacheable(
    *,
    cache_enabled: bool,
    embedding_model: str | None,
    embedding_dim: int | None,
) -> bool:
    """Report whether the embedding cache may be used for this configuration.

    All three must hold: the cache is switched on, an embedding model name is
    configured (without one there is nothing to key against), and a non-zero
    embedding dimension is known (without one the blob's integrity guard cannot
    run).
    """
    prerequisites = (cache_enabled, embedding_model, embedding_dim)
    return all(prerequisites)
|
||||
9
surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py
vendored
Normal file
9
surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Background pruning of the embedding cache by age and size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .task import evict_embedding_cache_task
|
||||
|
||||
__all__ = [
|
||||
"evict_embedding_cache_task",
|
||||
]
|
||||
68
surfsense_backend/app/indexing_pipeline/cache/eviction/task.py
vendored
Normal file
68
surfsense_backend/app/indexing_pipeline/cache/eviction/task.py
vendored
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Celery task that prunes the embedding cache by TTL, then by size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.etl_pipeline.cache.eviction.policy import select_over_budget
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository
|
||||
from app.indexing_pipeline.cache.settings import load_embedding_cache_settings
|
||||
from app.indexing_pipeline.cache.storage import EmbeddingCacheStore
|
||||
from app.observability import metrics
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="evict_embedding_cache")
def evict_embedding_cache_task():
    """Celery task ``evict_embedding_cache``: run ``_evict`` via ``run_async_celery_task``."""
    return run_async_celery_task(_evict)
|
||||
|
||||
|
||||
async def _evict() -> None:
    """Prune the embedding cache in two phases: TTL first, then size budget.

    Phase one drops up to ``eviction_batch`` entries whose ``last_used_at`` is
    older than ``ttl_days``. Phase two runs only if the index total is still
    above ``max_total_bytes`` afterwards, and drops the coldest entries picked
    by ``select_over_budget``. Nothing happens when the cache is disabled.
    """
    settings = load_embedding_cache_settings()
    if not settings.enabled:
        return

    budget = settings.max_total_bytes
    batch = settings.eviction_batch
    store = EmbeddingCacheStore()
    async with get_celery_session_maker()() as session:
        index = CachedEmbeddingSetRepository(session)

        oldest_allowed = datetime.now(UTC) - timedelta(days=settings.ttl_days)
        stale = await index.select_expired(cutoff=oldest_allowed, limit=batch)
        await _drop(index, store, stale, phase="ttl")

        # Measured after the TTL pass, so expired entries no longer count.
        total = await index.total_size_bytes()
        if total <= budget:
            return

        coldest = await index.select_coldest(limit=batch)
        overflow = select_over_budget(
            coldest,
            current_total_bytes=total,
            max_total_bytes=budget,
        )
        await _drop(index, store, overflow, phase="size")
|
||||
|
||||
|
||||
async def _drop(
    index: CachedEmbeddingSetRepository,
    store: EmbeddingCacheStore,
    candidates: list[EvictionCandidate],
    *,
    phase: str,
) -> None:
    """Delete the candidates' blobs, then their index rows, and record the eviction.

    Args:
        index: repository used to delete the index rows.
        store: blob store that the candidates' ``storage_key`` values refer to.
        candidates: entries to evict; an empty list is a no-op.
        phase: label forwarded to the eviction metric and the log line.
    """
    if not candidates:
        return
    for candidate in candidates:
        # Blob deletion is best-effort: the index row is dropped regardless, so a
        # failure here only leaves an orphan blob. Log it instead of swallowing
        # it silently, so leaked storage stays traceable.
        try:
            await store.delete(candidate.storage_key)
        except Exception:
            logger.warning(
                "Could not delete cached embedding blob %s; dropping index row anyway",
                candidate.storage_key,
                exc_info=True,
            )
    await index.delete_by_ids([candidate.id for candidate in candidates])
    metrics.record_embedding_cache_eviction(len(candidates), phase=phase)
    logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase)
|
||||
11
surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py
vendored
Normal file
11
surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Database access for cached embedding sets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import CachedEmbeddingSet
|
||||
from .repository import CachedEmbeddingSetRepository
|
||||
|
||||
__all__ = [
|
||||
"CachedEmbeddingSet",
|
||||
"CachedEmbeddingSetRepository",
|
||||
]
|
||||
47
surfsense_backend/app/indexing_pipeline/cache/persistence/models.py
vendored
Normal file
47
surfsense_backend/app/indexing_pipeline/cache/persistence/models.py
vendored
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""``embedding_cache_sets``: one reusable chunk+embedding set per markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
|
||||
class CachedEmbeddingSet(BaseModel, TimestampMixin):
    """Index row for one cached chunk+embedding set; the vectors live in a separate blob.

    Rows are unique per (markdown_sha256, embedding_model, chunker_kind,
    chunker_version). ``embedding_dim`` is stored alongside but is not part of
    that unique constraint.
    """

    __tablename__ = "embedding_cache_sets"

    # Key: markdown text + the recipe that turned it into vectors.
    markdown_sha256 = Column(String(64), nullable=False)
    embedding_model = Column(String(255), nullable=False)
    embedding_dim = Column(Integer, nullable=False)
    chunker_kind = Column(String(8), nullable=False)
    chunker_version = Column(Integer, nullable=False)

    # Where the embedding blob lives (kept out of the row to stay small).
    storage_backend = Column(String(32), nullable=False)
    storage_key = Column(String, nullable=False)
    size_bytes = Column(BigInteger, nullable=False)
    chunk_count = Column(Integer, nullable=False, default=0, server_default="0")

    # Drives eviction (popularity + recency).
    times_reused = Column(BigInteger, nullable=False, default=0, server_default="0")
    last_used_at = Column(DateTime(timezone=True), nullable=False)

    __table_args__ = (
        UniqueConstraint(
            "markdown_sha256",
            "embedding_model",
            "chunker_kind",
            "chunker_version",
            name="uq_embedding_cache_sets_key",
        ),
        # Supports the eviction queries that filter and order by last_used_at.
        Index("ix_embedding_cache_sets_last_used_at", "last_used_at"),
    )
|
||||
126
surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py
vendored
Normal file
126
surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py
vendored
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""CRUD and eviction selectors for ``embedding_cache_sets`` (no business rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey
|
||||
|
||||
from .models import CachedEmbeddingSet
|
||||
|
||||
# Columns fetched by the eviction selectors; they are the attributes that
# ``_as_eviction_candidate`` reads off each result row.
_EVICTION_COLUMNS = (
    CachedEmbeddingSet.id,
    CachedEmbeddingSet.storage_key,
    CachedEmbeddingSet.size_bytes,
    CachedEmbeddingSet.last_used_at,
    CachedEmbeddingSet.times_reused,
)
|
||||
|
||||
|
||||
def _as_eviction_candidate(row) -> EvictionCandidate:
    """Copy the eviction columns of one result row into an ``EvictionCandidate``."""
    column_names = ("id", "storage_key", "size_bytes", "last_used_at", "times_reused")
    fields = {name: getattr(row, name) for name in column_names}
    return EvictionCandidate(**fields)
|
||||
|
||||
|
||||
class CachedEmbeddingSetRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get(self, key: EmbeddingKey) -> CachedEmbeddingSet | None:
|
||||
result = await self._session.execute(
|
||||
select(CachedEmbeddingSet).where(
|
||||
CachedEmbeddingSet.markdown_sha256 == key.markdown_sha256,
|
||||
CachedEmbeddingSet.embedding_model == key.embedding_model,
|
||||
CachedEmbeddingSet.chunker_kind == key.chunker_kind,
|
||||
CachedEmbeddingSet.chunker_version == key.chunker_version,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
*,
|
||||
key: EmbeddingKey,
|
||||
storage_backend: str,
|
||||
storage_key: str,
|
||||
size_bytes: int,
|
||||
chunk_count: int,
|
||||
) -> None:
|
||||
# Concurrent writers embed identical markdown, so a lost race is harmless.
|
||||
now = datetime.now(UTC)
|
||||
await self._session.execute(
|
||||
pg_insert(CachedEmbeddingSet)
|
||||
.values(
|
||||
markdown_sha256=key.markdown_sha256,
|
||||
embedding_model=key.embedding_model,
|
||||
embedding_dim=key.embedding_dim,
|
||||
chunker_kind=key.chunker_kind,
|
||||
chunker_version=key.chunker_version,
|
||||
storage_backend=storage_backend,
|
||||
storage_key=storage_key,
|
||||
size_bytes=size_bytes,
|
||||
chunk_count=chunk_count,
|
||||
times_reused=0,
|
||||
last_used_at=now,
|
||||
created_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(constraint="uq_embedding_cache_sets_key")
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def mark_used(self, row_id: int) -> None:
|
||||
await self._session.execute(
|
||||
update(CachedEmbeddingSet)
|
||||
.where(CachedEmbeddingSet.id == row_id)
|
||||
.values(
|
||||
times_reused=CachedEmbeddingSet.times_reused + 1,
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def total_size_bytes(self) -> int:
|
||||
result = await self._session.execute(
|
||||
select(func.coalesce(func.sum(CachedEmbeddingSet.size_bytes), 0))
|
||||
)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def select_expired(
|
||||
self, *, cutoff: datetime, limit: int
|
||||
) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.where(CachedEmbeddingSet.last_used_at < cutoff)
|
||||
.order_by(CachedEmbeddingSet.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.order_by(
|
||||
CachedEmbeddingSet.times_reused.asc(),
|
||||
CachedEmbeddingSet.last_used_at.asc(),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def delete_by_ids(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
await self._session.execute(
|
||||
delete(CachedEmbeddingSet).where(CachedEmbeddingSet.id.in_(ids))
|
||||
)
|
||||
await self._session.commit()
|
||||
12
surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py
vendored
Normal file
12
surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Pure value objects for the embedding cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .embedding_key import EmbeddingKey
|
||||
from .embedding_set import CachedChunk, EmbeddingSet
|
||||
|
||||
__all__ = [
|
||||
"CachedChunk",
|
||||
"EmbeddingKey",
|
||||
"EmbeddingSet",
|
||||
]
|
||||
27
surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py
vendored
Normal file
27
surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Identity of a cacheable embedding set: equal keys yield identical vectors.
|
||||
|
||||
Embeddings depend on the markdown text, the embedding model, and the chunker --
|
||||
never on how the markdown was produced. So the key is the markdown's own hash
|
||||
plus the model and chunker recipe, not the upstream parse identity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EmbeddingKey:
|
||||
markdown_sha256: str
|
||||
embedding_model: str
|
||||
embedding_dim: int
|
||||
chunker_kind: str
|
||||
chunker_version: int
|
||||
|
||||
@property
|
||||
def object_suffix(self) -> str:
|
||||
# Fingerprint the model so distinct models never share a blob, while the
|
||||
# markdown hash (the object's folder) stays human-readable.
|
||||
fingerprint = hashlib.sha256(self.embedding_model.encode("utf-8")).hexdigest()
|
||||
return f"{fingerprint[:16]}.{self.chunker_kind}.v{self.chunker_version}.emb"
|
||||
29
surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py
vendored
Normal file
29
surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""The cached payload: a document's chunk texts paired with their vectors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class CachedChunk:
    """One chunk's text together with its embedding vector."""

    text: str
    embedding: np.ndarray
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class EmbeddingSet:
    """Everything the indexer needs to rebuild a document's chunks without embedding.

    ``summary_embedding`` is the document-level vector; ``chunks`` are the ordered
    chunk texts and their vectors.
    """

    summary_embedding: np.ndarray
    chunks: list[CachedChunk]

    @property
    def chunk_count(self) -> int:
        """Number of chunks in the set (the summary vector is not counted)."""
        return len(self.chunks)
|
||||
75
surfsense_backend/app/indexing_pipeline/cache/serialization.py
vendored
Normal file
75
surfsense_backend/app/indexing_pipeline/cache/serialization.py
vendored
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle).
|
||||
|
||||
Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header
|
||||
carries the dim, chunk count, and ordered chunk texts; the matrix holds the
|
||||
summary vector followed by one row per chunk, all float32 for compactness.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet
|
||||
|
||||
# Marker at the start of every blob: "SurfSense EMBeddings, version 1" -> SSEMB1. Lets us
# reject foreign blobs and bump the trailing digit if the layout ever changes.
_MAGIC = b"SSEMB1"
# 4-byte big-endian unsigned int written before the variable-length JSON header,
# so the reader knows where the header ends and the float matrix begins.
_HEADER_LEN = struct.Struct(">I")
|
||||
|
||||
|
||||
def serialize(embedding_set: EmbeddingSet) -> bytes:
    """Pack an embedding set into ``MAGIC | header length | JSON header | float32 matrix``.

    The matrix holds the flattened summary vector as row 0, followed by one row
    per chunk in order; the JSON header records ``dim``, ``count`` and the
    ordered chunk ``texts``.

    Raises:
        ValueError: if a chunk vector's flattened length differs from the
            summary vector's.
    """

    def as_row(values) -> np.ndarray:
        # Every vector is stored flattened as float32.
        return np.asarray(values, dtype=np.float32).reshape(-1)

    summary_row = as_row(embedding_set.summary_embedding)
    dim = int(summary_row.shape[0])

    matrix_rows = [summary_row]
    chunk_texts: list[str] = []
    for chunk in embedding_set.chunks:
        row = as_row(chunk.embedding)
        if row.shape[0] != dim:
            raise ValueError(
                "All vectors in an embedding set must share one dimension."
            )
        matrix_rows.append(row)
        chunk_texts.append(chunk.text)

    body = np.stack(matrix_rows, axis=0).tobytes(order="C")
    header_fields = {"dim": dim, "count": len(chunk_texts), "texts": chunk_texts}
    header_bytes = json.dumps(header_fields, ensure_ascii=False).encode("utf-8")
    return _MAGIC + _HEADER_LEN.pack(len(header_bytes)) + header_bytes + body
|
||||
|
||||
|
||||
def deserialize(blob: bytes) -> EmbeddingSet:
    """Decode a blob written by ``serialize`` back into an ``EmbeddingSet``.

    The returned vectors are NumPy views over the blob's buffer, not copies.

    Raises:
        ValueError: if the magic marker is absent, the blob is shorter than its
            declared header, the header is malformed, or the header's chunk
            count disagrees with its texts or with the matrix size.
    """
    corrupt = "Embedding cache blob is truncated or corrupt."
    view = memoryview(blob)
    if bytes(view[: len(_MAGIC)]) != _MAGIC:
        raise ValueError("Unrecognized embedding cache blob.")

    offset = len(_MAGIC)
    # Bounds-check before unpacking so a cut-off blob reports ValueError like
    # every other corruption, instead of leaking struct.error.
    if len(view) < offset + _HEADER_LEN.size:
        raise ValueError(corrupt)
    (header_len,) = _HEADER_LEN.unpack_from(view, offset)
    offset += _HEADER_LEN.size

    if len(view) < offset + header_len:
        raise ValueError(corrupt)
    header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8"))
    offset += header_len

    try:
        dim = int(header["dim"])
        count = int(header["count"])
        texts: list[str] = header["texts"]
    except (KeyError, TypeError) as exc:
        raise ValueError(corrupt) from exc
    # The header must carry exactly one text per chunk row; otherwise pairing
    # texts with rows would raise IndexError or silently drop texts.
    if dim < 0 or count < 0 or not isinstance(texts, list) or len(texts) != count:
        raise ValueError(corrupt)

    matrix = np.frombuffer(view[offset:], dtype=np.float32)
    if matrix.shape[0] != (count + 1) * dim:
        raise ValueError(corrupt)
    matrix = matrix.reshape(count + 1, dim)

    # Row 0 is the summary vector; rows 1..count line up with ``texts``.
    return EmbeddingSet(
        summary_embedding=matrix[0],
        chunks=[
            CachedChunk(text=text, embedding=matrix[row])
            for row, text in enumerate(texts, start=1)
        ],
    )
|
||||
51
surfsense_backend/app/indexing_pipeline/cache/service.py
vendored
Normal file
51
surfsense_backend/app/indexing_pipeline/cache/service.py
vendored
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""Recall and remember embedding sets, coordinating the index and blob store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.storage import EmbeddingCacheStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingCacheService:
    """Pairs the index repository with the blob store to serve and fill the cache."""

    def __init__(self, session: AsyncSession) -> None:
        self._index = CachedEmbeddingSetRepository(session)
        self._store = EmbeddingCacheStore()

    async def recall(self, key: EmbeddingKey) -> EmbeddingSet | None:
        """Return the cached embedding set for ``key``, or ``None`` on a miss.

        A miss is reported when no index row exists, when the blob cannot be
        loaded, or when the stored summary vector's length differs from
        ``key.embedding_dim``. A served hit is recorded via ``mark_used``.
        """
        row = await self._index.get(key)
        if row is None:
            return None

        try:
            cached = await self._store.load(row.storage_key)
        except Exception:
            # Index points at a blob that is gone; treat as a miss and re-embed.
            logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True)
            return None

        stored_dim = int(cached.summary_embedding.shape[0])
        if stored_dim != key.embedding_dim:
            # A model swapped its dimension under a reused name; never serve it.
            logger.warning("Cached embedding dimension mismatch: %s", row.storage_key)
            return None

        await self._index.mark_used(row.id)
        return cached

    async def remember(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> None:
        """Write the blob first, then record it in the index for future reuse."""
        blob_key, blob_size = await self._store.save(key, embedding_set)
        await self._index.insert(
            key=key,
            storage_backend=self._store.backend_name,
            storage_key=blob_key,
            size_bytes=blob_size,
            chunk_count=embedding_set.chunk_count,
        )
|
||||
30
surfsense_backend/app/indexing_pipeline/cache/settings.py
vendored
Normal file
30
surfsense_backend/app/indexing_pipeline/cache/settings.py
vendored
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""Embedding-cache configuration resolved from the central ``Config``.
|
||||
|
||||
The blob backend is intentionally not configured here: it is shared with the ETL
|
||||
parse cache (see ``ETL_CACHE_STORAGE_*``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EmbeddingCacheSettings:
    """Immutable snapshot of the embedding-cache configuration."""

    # Master switch; caching and eviction are skipped when it is false.
    enabled: bool
    # Part of every cache key, so entries written under another version never match.
    chunker_version: int
    # Days since last use after which the eviction task expires an entry.
    ttl_days: int
    # Size budget for the whole cache, in bytes.
    max_total_bytes: int
    # Upper bound on rows selected per eviction query.
    eviction_batch: int
|
||||
|
||||
|
||||
def load_embedding_cache_settings() -> EmbeddingCacheSettings:
    """Build an ``EmbeddingCacheSettings`` from the ``EMBEDDING_CACHE_*`` values on ``config``.

    ``EMBEDDING_CACHE_MAX_TOTAL_MB`` is multiplied by 1024 * 1024 to give the
    byte budget.
    """
    from app.config import config

    bytes_per_mb = 1024 * 1024
    return EmbeddingCacheSettings(
        enabled=config.EMBEDDING_CACHE_ENABLED,
        chunker_version=config.EMBEDDING_CACHE_CHUNKER_VERSION,
        ttl_days=config.EMBEDDING_CACHE_TTL_DAYS,
        max_total_bytes=config.EMBEDDING_CACHE_MAX_TOTAL_MB * bytes_per_mb,
        eviction_batch=config.EMBEDDING_CACHE_EVICTION_BATCH,
    )
|
||||
9
surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py
vendored
Normal file
9
surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Blob storage for cached embedding sets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .embedding_store import EmbeddingCacheStore
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingCacheStore",
|
||||
]
|
||||
39
surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py
vendored
Normal file
39
surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Read and write cached embedding blobs through the shared cache backend.
|
||||
|
||||
The blob backend is shared with the ETL parse cache (same bucket / root), so
|
||||
markdown and its embeddings live side by side; only the object prefix differs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.storage.backend import resolve_cache_backend
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.serialization import deserialize, serialize
|
||||
from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key
|
||||
|
||||
_EMBEDDING_CONTENT_TYPE = "application/octet-stream"
|
||||
|
||||
|
||||
class EmbeddingCacheStore:
    """Saves, loads and deletes embedding blobs on the backend from ``resolve_cache_backend``."""

    def __init__(self) -> None:
        self._backend = resolve_cache_backend()

    @property
    def backend_name(self) -> str:
        """Name reported by the underlying backend."""
        backend = self._backend
        return backend.backend_name

    async def save(
        self, key: EmbeddingKey, embedding_set: EmbeddingSet
    ) -> tuple[str, int]:
        """Serialize the set, write it under the key's object key, and return ``(storage_key, byte_size)``."""
        payload = serialize(embedding_set)
        object_key = build_embedding_object_key(key)
        await self._backend.put(object_key, payload, content_type=_EMBEDDING_CONTENT_TYPE)
        return object_key, len(payload)

    async def load(self, storage_key: str) -> EmbeddingSet:
        """Read the whole blob stream for ``storage_key`` and deserialize it."""
        pieces = []
        async for piece in self._backend.open_stream(storage_key):
            pieces.append(piece)
        return deserialize(b"".join(pieces))

    async def delete(self, storage_key: str) -> None:
        """Delete the blob stored under ``storage_key``."""
        backend = self._backend
        await backend.delete(storage_key)
|
||||
12
surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py
vendored
Normal file
12
surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py
vendored
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Object keys for cached embedding sets, namespaced under a dedicated prefix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey
|
||||
|
||||
CACHE_PREFIX = "embedding_cache"


def build_embedding_object_key(key: EmbeddingKey) -> str:
    """Return ``<CACHE_PREFIX>/<markdown_sha256>/<object_suffix>`` for ``key``."""
    # Content-addressed: identical markdown + recipe always map to the same key.
    segments = (CACHE_PREFIX, key.markdown_sha256, key.object_suffix)
    return "/".join(segments)
|
||||
56
surfsense_backend/app/indexing_pipeline/chunk_reconciler.py
Normal file
56
surfsense_backend/app/indexing_pipeline/chunk_reconciler.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Diff a document's existing chunk rows against its freshly chunked texts.
|
||||
|
||||
Embeddings are a pure function of chunk text, so a row whose content reappears
|
||||
in the new chunking keeps its embedding (and its HNSW/GIN index entries); only
|
||||
genuinely new texts are embedded and only vanished rows are deleted. Matching
|
||||
is a greedy multiset match on content in document order, so duplicate
|
||||
boilerplate chunks pair up one-to-one and reordered chunks become cheap
|
||||
position updates instead of delete+reinsert.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ExistingChunk:
    """A stored chunk row as ``reconcile`` sees it: its id, text, and current position."""

    id: int
    content: str
    position: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ChunkPlan:
    """The minimal set of writes that turns the stored chunks into the new ones.

    ``reused`` holds only kept rows whose position actually changed; rows that
    match in place need no write at all. Kept-row count (for metrics) is
    ``len(existing) - len(to_delete)``. ``to_embed`` lists new texts that no
    stored row matched; ``to_delete`` lists stored rows that no new text claimed.
    """

    reused: list[tuple[int, int]]  # (existing_chunk_id, new_position)
    to_embed: list[tuple[int, str]]  # (new_position, text)
    to_delete: list[int]  # existing chunk ids
|
||||
|
||||
|
||||
def reconcile(existing: list[ExistingChunk], new_texts: list[str]) -> ChunkPlan:
    """Match stored chunks to new chunk texts and return the resulting write plan.

    Stored chunks are grouped by content in position order. Each new text then
    claims the earliest unclaimed stored chunk with identical content. A claimed
    chunk is listed in ``reused`` only if its position changed; unmatched new
    texts go to ``to_embed`` and unclaimed stored chunks to ``to_delete``.
    """
    pool: dict[str, deque[ExistingChunk]] = defaultdict(deque)
    for stored in sorted(existing, key=lambda c: c.position):
        pool[stored.content].append(stored)

    moved: list[tuple[int, int]] = []
    fresh: list[tuple[int, str]] = []
    for slot, text in enumerate(new_texts):
        # .get() so a miss does not create an empty queue in the defaultdict.
        candidates = pool.get(text)
        if not candidates:
            fresh.append((slot, text))
            continue
        kept = candidates.popleft()
        if kept.position != slot:
            moved.append((kept.id, slot))

    leftovers = [stored.id for queue in pool.values() for stored in queue]
    return ChunkPlan(reused=moved, to_embed=fresh, to_delete=leftovers)
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,7 +22,6 @@ async def rollback_and_persist_failure(
|
|||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
# Session is completely dead; surface it but never raise.
|
||||
logger.warning(
|
||||
"Rollback failed; cannot persist failed status for document %s",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -35,8 +34,6 @@ async def rollback_and_persist_failure(
|
|||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
# Best-effort: the document stays non-ready and is retried next sync.
|
||||
# Log it so a permanently-stuck document is at least traceable.
|
||||
logger.warning(
|
||||
"Could not persist failed status for document %s; will retry next sync",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -46,12 +43,60 @@ async def rollback_and_persist_failure(
|
|||
await session.rollback()
|
||||
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
async def persist_scratch_index(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
content: str,
|
||||
chunks: list[Chunk],
|
||||
*,
|
||||
batch_size: int,
|
||||
perf: logging.Logger,
|
||||
) -> None:
|
||||
"""Commit document content first, then chunk rows in batches, then mark ready."""
|
||||
if document.id is None:
|
||||
raise ValueError("document.id is required to persist chunks")
|
||||
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
t_persist = time.perf_counter()
|
||||
total = len(chunks)
|
||||
if total == 0:
|
||||
set_committed_value(document, "chunks", [])
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
effective_batch = total if batch_size <= 0 else batch_size
|
||||
num_batches = (total + effective_batch - 1) // effective_batch
|
||||
doc_id = document.id
|
||||
|
||||
for batch_idx, start in enumerate(range(0, total, effective_batch), start=1):
|
||||
batch = chunks[start : start + effective_batch]
|
||||
t_batch = time.perf_counter()
|
||||
for chunk in batch:
|
||||
chunk.document_id = doc_id
|
||||
session.add_all(batch)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs",
|
||||
doc_id,
|
||||
batch_idx,
|
||||
num_batches,
|
||||
len(batch),
|
||||
time.perf_counter() - t_batch,
|
||||
)
|
||||
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs",
|
||||
doc_id,
|
||||
total,
|
||||
num_batches,
|
||||
time.perf_counter() - t_persist,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from litellm.exceptions import (
|
|||
)
|
||||
from sqlalchemy.exc import IntegrityError as IntegrityError
|
||||
|
||||
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
|
|
@ -97,38 +99,20 @@ def safe_exception_message(exc: Exception) -> str:
|
|||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -19,16 +19,17 @@ from app.db import (
|
|||
DocumentStatus,
|
||||
DocumentType,
|
||||
)
|
||||
from app.indexing_pipeline.cache import build_chunk_embeddings
|
||||
from app.indexing_pipeline.cache.cached_indexing import chunk_markdown, embed_batch
|
||||
from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
persist_scratch_index,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
|
|
@ -380,53 +381,50 @@ class IndexingPipelineService:
|
|||
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
if connector_doc.should_use_code_chunker:
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=True,
|
||||
existing = await self._load_existing_chunks(document.id)
|
||||
if existing and self._reconcile_enabled():
|
||||
chunk_count = await self._reindex_incrementally(
|
||||
document, content, connector_doc, existing
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
else:
|
||||
# Use the table-aware hybrid chunker so Markdown tables are not
|
||||
# split mid-row (see issue #1334).
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text_hybrid,
|
||||
connector_doc.source_markdown,
|
||||
from app.config import config
|
||||
|
||||
chunks = await self._reindex_from_scratch(
|
||||
document, content, connector_doc
|
||||
)
|
||||
chunk_count = len(chunks)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
await persist_scratch_index(
|
||||
self.session,
|
||||
document,
|
||||
content,
|
||||
chunks,
|
||||
batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE,
|
||||
perf=perf,
|
||||
)
|
||||
|
||||
texts_to_embed = [content, *chunk_texts]
|
||||
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
|
||||
summary_embedding, *chunk_embeddings = embeddings
|
||||
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=emb)
|
||||
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
||||
]
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
document.embedding = summary_embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
chunk_count,
|
||||
time.perf_counter() - t_index,
|
||||
)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
log_index_success(ctx, chunk_count=chunk_count)
|
||||
outcome_status = "success"
|
||||
|
||||
await self._enqueue_ai_sort_if_enabled(document)
|
||||
|
|
@ -483,6 +481,89 @@ class IndexingPipelineService:
|
|||
persist_span_cm.__exit__(*sys.exc_info())
|
||||
return document
|
||||
|
||||
@staticmethod
def _reconcile_enabled() -> bool:
    """Return the ``CHUNK_RECONCILE_ENABLED`` setting from the app config.

    The config object is imported inside the function and the attribute is
    read on every call.
    """
    from app.config import config

    return config.CHUNK_RECONCILE_ENABLED
|
||||
|
||||
async def _load_existing_chunks(self, document_id: int) -> list[ExistingChunk]:
    """Fetch the stored chunks of a document as ``ExistingChunk`` records.

    Only the ``id``, ``content`` and ``position`` columns are selected; the
    embedding column is not part of the query.

    Args:
        document_id: Value matched against ``Chunk.document_id``.

    Returns:
        One ``ExistingChunk`` per matching row, in whatever order the
        database returns them (the query has no ``ORDER BY``).
    """
    result = await self.session.execute(
        select(Chunk.id, Chunk.content, Chunk.position).where(
            Chunk.document_id == document_id
        )
    )
    return [
        ExistingChunk(id=row.id, content=row.content, position=row.position)
        for row in result
    ]
|
||||
|
||||
async def _reindex_from_scratch(
    self, document: Document, content: str, connector_doc: ConnectorDocument
) -> list[Chunk]:
    """Drop the document's chunk rows and build a fresh chunk list.

    Issues a ``DELETE`` for every ``Chunk`` row of ``document``, obtains the
    summary embedding and the ``(text, embedding)`` pairs from
    ``build_chunk_embeddings``, and stores the summary embedding on
    ``document``. The new ``Chunk`` objects are neither added to the session
    nor committed here.

    Args:
        document: Document whose chunks are replaced; its ``embedding``
            attribute is overwritten.
        content: Text handed to ``build_chunk_embeddings``.
        connector_doc: Supplies ``should_use_code_chunker``.

    Returns:
        New ``Chunk`` objects whose ``position`` is their index in the
        sequence of pairs returned by ``build_chunk_embeddings``.
    """
    await self.session.execute(
        delete(Chunk).where(Chunk.document_id == document.id)
    )

    summary_embedding, chunk_pairs = await build_chunk_embeddings(
        content,
        use_code_chunker=connector_doc.should_use_code_chunker,
    )

    document.embedding = summary_embedding
    return [
        Chunk(content=text, embedding=emb, position=i)
        for i, (text, emb) in enumerate(chunk_pairs)
    ]
|
||||
|
||||
async def _reindex_incrementally(
    self,
    document: Document,
    content: str,
    connector_doc: ConnectorDocument,
    existing: list[ExistingChunk],
) -> int:
    """Edit path: keep rows whose text survived, embed only new texts.

    Unchanged rows keep their embedding and their HNSW/GIN index entries;
    moved rows get a position-only UPDATE, which touches neither index.

    The statements are executed on ``self.session`` but nothing is committed
    here.

    Args:
        document: Document being re-indexed; its ``embedding`` attribute is
            overwritten with the new summary embedding.
        content: New document text; it is chunked and also embedded whole as
            the summary vector.
        connector_doc: Supplies ``should_use_code_chunker``.
        existing: Chunks currently stored for the document, passed to
            ``reconcile`` together with the new chunk texts.

    Returns:
        The number of chunk texts produced for the new content.
    """
    new_texts = await chunk_markdown(
        content, use_code_chunker=connector_doc.should_use_code_chunker
    )
    plan = reconcile(existing, new_texts)

    # One batch: the document-level summary vector plus the missing chunks.
    embeddings = await embed_batch([content, *[t for _, t in plan.to_embed]])
    summary_embedding, *new_embeddings = embeddings

    if plan.reused:
        # Bulk UPDATE: each parameter dict carries a row id and its new position.
        await self.session.execute(
            update(Chunk),
            [{"id": cid, "position": pos} for cid, pos in plan.reused],
        )
    if plan.to_delete:
        await self.session.execute(
            delete(Chunk).where(Chunk.id.in_(plan.to_delete))
        )
    # strict=True: a length mismatch between planned texts and returned
    # embeddings raises instead of silently dropping chunks.
    self.session.add_all(
        Chunk(
            content=text,
            embedding=emb,
            position=pos,
            document_id=document.id,
        )
        for (pos, text), emb in zip(plan.to_embed, new_embeddings, strict=True)
    )
    document.embedding = summary_embedding

    # "reused" is reported as every pre-existing row that was not deleted.
    ot_metrics.record_chunk_reconcile(
        reused=len(existing) - len(plan.to_delete),
        embedded=len(plan.to_embed),
        deleted=len(plan.to_delete),
    )
    return len(new_texts)
|
||||
|
||||
async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None:
|
||||
"""Fire-and-forget: enqueue incremental AI sort if the search space has it enabled."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
# Matches notifications.title VARCHAR(200).
|
||||
TITLE_MAX_LENGTH = 200
|
||||
|
||||
# Notifications newer than this are live-synced; older ones load via the list endpoint.
|
||||
SYNC_WINDOW_DAYS = 14
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class DocumentProcessingNotificationHandler(BaseNotificationHandler):
|
|||
) -> Notification:
|
||||
"""Open the notification when document processing is queued."""
|
||||
operation_id = msg.operation_id(document_type, document_name, search_space_id)
|
||||
title = f"Processing: {document_name}"
|
||||
title = msg.started_title(document_name)
|
||||
message = "Waiting in queue"
|
||||
|
||||
metadata = {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import hashlib
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.notifications.service.messages.text import format_title
|
||||
|
||||
|
||||
def operation_id(document_type: str, filename: str, search_space_id: int) -> str:
|
||||
"""Build a unique id for a document processing run."""
|
||||
|
|
@ -14,6 +16,11 @@ def operation_id(document_type: str, filename: str, search_space_id: int) -> str
|
|||
return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}"
|
||||
|
||||
|
||||
def started_title(document_name: str) -> str:
    """Title shown when document processing is queued.

    Delegates to ``format_title`` with the ``"Processing: "`` prefix.
    """
    return format_title("Processing: ", document_name)
|
||||
|
||||
|
||||
def progress(
|
||||
stage: str,
|
||||
stage_message: str | None = None,
|
||||
|
|
@ -44,11 +51,11 @@ def completion(
|
|||
) -> tuple[str, str, str, dict[str, Any]]:
|
||||
"""Compute the final title, message, status, and metadata for a finished run."""
|
||||
if error_message:
|
||||
title = f"Failed: {document_name}"
|
||||
title = format_title("Failed: ", document_name)
|
||||
message = f"Processing failed: {error_message}"
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {document_name}"
|
||||
title = format_title("Ready: ", document_name)
|
||||
message = "Now searchable!"
|
||||
status = "completed"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,21 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.notifications.constants import TITLE_MAX_LENGTH
|
||||
|
||||
|
||||
def truncate(text: str, limit: int) -> str:
    """Cut ``text`` to its first ``limit`` chars, appending an ellipsis if cut.

    When ``text`` is longer than ``limit`` the result is ``text[:limit]``
    followed by ``"..."``, so it can be up to ``limit + 3`` characters long.
    Otherwise ``text`` is returned unchanged.
    """
    return text[:limit] + "..." if len(text) > limit else text
|
||||
|
||||
|
||||
def format_title(prefix: str, text: str, *, max_length: int = TITLE_MAX_LENGTH) -> str:
    """Build a notification title that fits ``max_length`` including ``prefix``.

    Cases, in order:

    * ``prefix`` alone uses up ``max_length``: return ``prefix`` cut to
      ``max_length`` and drop ``text`` entirely.
    * ``text`` fits in the remaining budget: return ``prefix + text``.
    * The remaining budget is 3 characters or fewer: hard-cut ``text`` to the
      budget, with no ellipsis.
    * Otherwise: cut ``text`` to ``budget - 3`` characters and append ``"..."``.

    NOTE(review): assumes ``max_length`` is non-negative; a negative value
    makes ``prefix[:max_length]`` slice from the end of ``prefix``.
    """
    budget = max_length - len(prefix)
    if budget <= 0:
        return prefix[:max_length]
    if len(text) <= budget:
        return f"{prefix}{text}"
    if budget <= 3:
        return f"{prefix}{text[:budget]}"
    return f"{prefix}{text[: budget - 3]}..."
|
||||
|
|
|
|||
|
|
@ -289,6 +289,49 @@ def _etl_extract_outcome():
|
|||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _etl_cache_lookups():
    """Return the ``surfsense.etl.cache.lookups`` counter.

    ``lru_cache`` memoizes the result, so ``create_counter`` runs once.
    """
    return _get_meter().create_counter(
        "surfsense.etl.cache.lookups",
        description="Count of ETL parse-cache lookups by outcome (hit/miss).",
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _etl_cache_evictions():
    """Return the ``surfsense.etl.cache.evictions`` counter.

    ``lru_cache`` memoizes the result, so ``create_counter`` runs once.
    """
    return _get_meter().create_counter(
        "surfsense.etl.cache.evictions",
        description="Count of ETL parse-cache entries evicted, by phase.",
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _embedding_cache_lookups():
    """Return the ``surfsense.embedding.cache.lookups`` counter.

    ``lru_cache`` memoizes the result, so ``create_counter`` runs once.
    """
    return _get_meter().create_counter(
        "surfsense.embedding.cache.lookups",
        description="Count of embedding (chunk+embedding) cache lookups by outcome (hit/miss).",
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _embedding_cache_evictions():
    """Return the ``surfsense.embedding.cache.evictions`` counter.

    ``lru_cache`` memoizes the result, so ``create_counter`` runs once.
    """
    return _get_meter().create_counter(
        "surfsense.embedding.cache.evictions",
        description="Count of embedding cache entries evicted, by phase.",
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _chunk_reconcile_chunks():
    """Return the ``surfsense.indexing.reconcile.chunks`` counter.

    ``lru_cache`` memoizes the result, so ``create_counter`` runs once.
    """
    return _get_meter().create_counter(
        "surfsense.indexing.reconcile.chunks",
        description=(
            "Chunks handled by incremental re-indexing, by outcome "
            "(reused/embedded/deleted)."
        ),
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _celery_heartbeat_refreshes():
|
||||
return _get_meter().create_counter(
|
||||
|
|
@ -670,6 +713,61 @@ def record_etl_extract_outcome(
|
|||
)
|
||||
|
||||
|
||||
def record_etl_cache_lookup(
    *, etl_service: str | None, mode: str | None, outcome: str
) -> None:
    """Record a parse-cache lookup. ``outcome`` is ``hit`` or ``miss``.

    Adds 1 to the ETL cache-lookup counter. A falsy ``etl_service`` or
    ``mode`` is reported under the attribute value ``"unknown"``.
    """
    _add(
        _etl_cache_lookups(),
        1,
        {
            "etl.service": etl_service or "unknown",
            "mode": mode or "unknown",
            "outcome": outcome,
        },
    )
|
||||
|
||||
|
||||
def record_etl_cache_eviction(count: int, *, phase: str) -> None:
    """Record evicted entries. ``phase`` is ``ttl`` or ``size``.

    Does nothing when ``count`` is zero or negative; otherwise adds ``count``
    to the ETL cache-eviction counter with ``phase`` as its attribute.
    """
    if count <= 0:
        return
    _add(_etl_cache_evictions(), count, {"phase": phase})
|
||||
|
||||
|
||||
def record_embedding_cache_lookup(
    *, embedding_model: str | None, chunker_kind: str | None, outcome: str
) -> None:
    """Record an embedding-cache lookup. ``outcome`` is ``hit`` or ``miss``.

    Adds 1 to the embedding cache-lookup counter. A falsy ``embedding_model``
    or ``chunker_kind`` is reported under the attribute value ``"unknown"``.
    """
    _add(
        _embedding_cache_lookups(),
        1,
        {
            "embedding.model": embedding_model or "unknown",
            "chunker.kind": chunker_kind or "unknown",
            "outcome": outcome,
        },
    )
|
||||
|
||||
|
||||
def record_embedding_cache_eviction(count: int, *, phase: str) -> None:
    """Record evicted entries. ``phase`` is ``ttl`` or ``size``.

    Does nothing when ``count`` is zero or negative; otherwise adds ``count``
    to the embedding cache-eviction counter with ``phase`` as its attribute.
    """
    if count <= 0:
        return
    _add(_embedding_cache_evictions(), count, {"phase": phase})
|
||||
|
||||
|
||||
def record_chunk_reconcile(*, reused: int, embedded: int, deleted: int) -> None:
    """Record an incremental re-index: how many chunks were kept vs recomputed.

    Each of the three counts is added to the reconcile counter under its own
    ``outcome`` attribute; counts that are zero or negative are skipped.
    """
    for outcome, count in (
        ("reused", reused),
        ("embedded", embedded),
        ("deleted", deleted),
    ):
        if count > 0:
            _add(_chunk_reconcile_chunks(), count, {"outcome": outcome})
|
||||
|
||||
|
||||
def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None:
    """Add 1 to the Celery heartbeat-refresh counter, tagged with ``heartbeat_type``."""
    _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type})
|
||||
|
||||
|
|
@ -863,9 +961,14 @@ __all__ = [
|
|||
"record_celery_queue_latency",
|
||||
"record_chat_request_duration",
|
||||
"record_chat_request_outcome",
|
||||
"record_chunk_reconcile",
|
||||
"record_compaction_run",
|
||||
"record_connector_sync_duration",
|
||||
"record_connector_sync_outcome",
|
||||
"record_embedding_cache_eviction",
|
||||
"record_embedding_cache_lookup",
|
||||
"record_etl_cache_eviction",
|
||||
"record_etl_cache_lookup",
|
||||
"record_etl_extract_duration",
|
||||
"record_etl_extract_outcome",
|
||||
"record_indexing_document_duration",
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from app.utils.rbac import check_permission
|
|||
|
||||
from .schemas import (
|
||||
CreatePodcastRequest,
|
||||
LanguageOptions,
|
||||
PodcastDetail,
|
||||
PodcastSummary,
|
||||
UpdateSpecRequest,
|
||||
|
|
@ -114,6 +115,20 @@ async def list_voices(language: str | None = None):
|
|||
]
|
||||
|
||||
|
||||
@router.get("/podcasts/languages", response_model=LanguageOptions)
async def list_languages():
    """Languages the active TTS provider can offer the brief editor.

    Returns:
        A ``LanguageOptions`` built from the voice catalog's
        ``offerable_languages`` result for the configured provider.

    Raises:
        HTTPException: 503 when ``app_config.TTS_SERVICE`` is unset or empty.
    """
    if not app_config.TTS_SERVICE:
        raise HTTPException(status_code=503, detail="No TTS provider configured")

    provider = provider_from_service(app_config.TTS_SERVICE)
    offering = get_voice_catalog().offerable_languages(provider)
    return LanguageOptions(
        languages=offering.languages,
        allows_custom=offering.allows_custom,
    )
|
||||
|
||||
|
||||
@router.get("/podcasts/voices/{voice_id}/preview")
|
||||
async def preview_voice(
|
||||
voice_id: str,
|
||||
|
|
|
|||
|
|
@ -63,6 +63,17 @@ class VoiceOption(BaseModel):
|
|||
gender: str
|
||||
|
||||
|
||||
class LanguageOptions(BaseModel):
    """The languages the brief editor may offer for the active provider.

    When ``allows_custom`` is true the list is a curated starting point and
    the editor accepts any BCP-47 tag beyond it.
    """

    # Language tags to present up front.
    languages: list[str]
    # Whether tags outside ``languages`` may be entered.
    allows_custom: bool
|
||||
|
||||
|
||||
class PodcastSummary(BaseModel):
|
||||
"""Lightweight list item."""
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ configured provider via :func:`provider_from_service`.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .catalog import VoiceCatalog, get_voice_catalog
|
||||
from .catalog import LanguageOffering, VoiceCatalog, get_voice_catalog
|
||||
from .preview import render_voice_preview
|
||||
from .provider import TtsProvider, provider_from_service
|
||||
from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender
|
||||
|
|
@ -14,6 +14,7 @@ from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender
|
|||
__all__ = [
|
||||
"ANY_LANGUAGE",
|
||||
"CatalogVoice",
|
||||
"LanguageOffering",
|
||||
"TtsProvider",
|
||||
"VoiceCatalog",
|
||||
"VoiceGender",
|
||||
|
|
|
|||
|
|
@ -9,11 +9,26 @@ provider-native reference.
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES
|
||||
from .data.languages import COMMON_LANGUAGES
|
||||
from .provider import TtsProvider
|
||||
from .voice import CatalogVoice
|
||||
from .voice import ANY_LANGUAGE, CatalogVoice
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LanguageOffering:
|
||||
"""The languages a provider's roster can offer the brief form.
|
||||
|
||||
``allows_custom`` is true when the roster has wildcard voices: the listed
|
||||
languages are then a curated starting point, not a limit, and any BCP-47
|
||||
tag may be entered.
|
||||
"""
|
||||
|
||||
languages: list[str]
|
||||
allows_custom: bool
|
||||
|
||||
|
||||
class VoiceCatalog:
|
||||
|
|
@ -44,6 +59,20 @@ class VoiceCatalog:
|
|||
"""Whether ``provider`` has at least one voice for ``language``."""
|
||||
return any(v.speaks(language) for v in self.for_provider(provider))
|
||||
|
||||
def offerable_languages(self, provider: TtsProvider) -> LanguageOffering:
    """The languages ``provider`` can offer up front.

    Language-bound voices contribute their concrete tags; wildcard voices
    cannot enumerate languages, so their presence merges in the curated
    common list and opens free entry.

    Args:
        provider: Provider whose voices are inspected via ``for_provider``.

    Returns:
        A ``LanguageOffering`` with the de-duplicated tags in sorted order
        and ``allows_custom`` set when at least one voice is a wildcard.
    """
    voices = self.for_provider(provider)
    tags = {v.language for v in voices if v.language != ANY_LANGUAGE}
    has_wildcard = any(v.language == ANY_LANGUAGE for v in voices)
    if has_wildcard:
        tags.update(COMMON_LANGUAGES)
    return LanguageOffering(languages=sorted(tags), allows_custom=has_wildcard)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_voice_catalog() -> VoiceCatalog:
|
||||
|
|
|
|||
33
surfsense_backend/app/podcasts/voices/data/languages.py
Normal file
33
surfsense_backend/app/podcasts/voices/data/languages.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Curated languages offered when a roster has wildcard (any-language) voices.
|
||||
|
||||
OpenAI-style multilingual voices speak whatever language the text is in, so
|
||||
there is no provider list to enumerate. This is the set the brief form offers
|
||||
up front for such providers; it is an offering, not a limit — the API flags
|
||||
``allows_custom`` so users can enter any BCP-47 tag beyond it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Language tags offered up front when a roster has wildcard voices.
# The entries are kept in alphabetical order.
COMMON_LANGUAGES: tuple[str, ...] = (
    "ar",
    "bn",
    "de",
    "en",
    "es",
    "fr",
    "hi",
    "id",
    "it",
    "ja",
    "ko",
    "nl",
    "pl",
    "pt",
    "ru",
    "sw",
    "th",
    "tr",
    "uk",
    "vi",
    "zh",
)
|
||||
|
|
@ -82,7 +82,7 @@ def build_configurable_system_prompt(
|
|||
*,
|
||||
model_name: str | None = None,
|
||||
) -> str:
|
||||
"""Build a configurable SurfSense system prompt (NewLLMConfig path).
|
||||
"""Build a configurable SurfSense system prompt.
|
||||
|
||||
See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt`
|
||||
for full parameter docs.
|
||||
|
|
@ -104,7 +104,7 @@ def build_configurable_system_prompt(
|
|||
def get_default_system_instructions() -> str:
|
||||
"""Return the default ``<system_instruction>`` block (no tools / citations).
|
||||
|
||||
Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``.
|
||||
Useful for populating the UI when editing custom system instructions.
|
||||
The output reflects the current fragment tree, not a baked-in constant.
|
||||
"""
|
||||
resolved_today = datetime.now(UTC).date().isoformat()
|
||||
|
|
|
|||
|
|
@ -348,8 +348,7 @@ def compose_system_prompt(
|
|||
mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
|
||||
an explicit MCP routing block.
|
||||
custom_system_instructions: Free-form instructions that override
|
||||
the default ``<system_instruction>`` block (legacy support
|
||||
for ``NewLLMConfig.system_instructions``).
|
||||
the default ``<system_instruction>`` block.
|
||||
use_default_system_instructions: When ``custom_system_instructions``
|
||||
is empty/None, fall back to defaults (legacy semantics).
|
||||
citations_enabled: Include ``citations_on.md`` (true) or
|
||||
|
|
|
|||
|
|
@ -420,7 +420,10 @@ class ChucksHybridSearchRetriever:
|
|||
select(
|
||||
Chunk.id.label("chunk_id"),
|
||||
func.row_number()
|
||||
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||
.over(
|
||||
partition_by=Chunk.document_id,
|
||||
order_by=(Chunk.position, Chunk.id),
|
||||
)
|
||||
.label("rn"),
|
||||
)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
|
|
@ -441,7 +444,7 @@ class ChucksHybridSearchRetriever:
|
|||
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||
.join(numbered, Chunk.id == numbered.c.chunk_id)
|
||||
.where(chunk_filter)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
)
|
||||
|
||||
t_fetch = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -357,7 +357,10 @@ class DocumentHybridSearchRetriever:
|
|||
select(
|
||||
Chunk.id.label("chunk_id"),
|
||||
func.row_number()
|
||||
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||
.over(
|
||||
partition_by=Chunk.document_id,
|
||||
order_by=(Chunk.position, Chunk.id),
|
||||
)
|
||||
.label("rn"),
|
||||
)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
|
|
@ -369,7 +372,7 @@ class DocumentHybridSearchRetriever:
|
|||
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||
.join(numbered, Chunk.id == numbered.c.chunk_id)
|
||||
.where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
)
|
||||
|
||||
t_fetch = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,10 @@ from .dropbox_add_connector_route import router as dropbox_add_connector_router
|
|||
from .editor_routes import router as editor_router
|
||||
from .export_routes import router as export_router
|
||||
from .folders_routes import router as folders_router
|
||||
from .gateway_webhook_routes import router as gateway_router
|
||||
from .gateway_webhook_routes import (
|
||||
config_router as gateway_config_router,
|
||||
router as gateway_router,
|
||||
)
|
||||
from .gateway_whatsapp_baileys_routes import router as gateway_whatsapp_baileys_router
|
||||
from .gateway_whatsapp_webhook_routes import router as gateway_whatsapp_webhook_router
|
||||
from .google_calendar_add_connector_route import (
|
||||
|
|
@ -44,9 +47,9 @@ from .logs_routes import router as logs_router
|
|||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .mcp_oauth_route import router as mcp_oauth_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_connections_routes import router as model_connections_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
|
|
@ -63,7 +66,6 @@ from .stripe_routes import router as stripe_router
|
|||
from .team_memory_routes import router as team_memory_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
from .vision_llm_routes import router as vision_llm_router
|
||||
from .youtube_routes import router as youtube_router
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -75,6 +77,7 @@ router.include_router(export_router)
|
|||
router.include_router(documents_router)
|
||||
router.include_router(folders_router)
|
||||
_gateway_enabled_dep = [Depends(require_gateway_enabled)]
|
||||
router.include_router(gateway_config_router)
|
||||
router.include_router(gateway_router, dependencies=_gateway_enabled_dep)
|
||||
router.include_router(
|
||||
gateway_whatsapp_webhook_router, dependencies=_gateway_enabled_dep
|
||||
|
|
@ -98,7 +101,6 @@ router.include_router(
|
|||
) # Video presentation status and streaming
|
||||
router.include_router(reports_router) # Report CRUD and multi-format export
|
||||
router.include_router(image_generation_router) # Image generation via litellm
|
||||
router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
@ -116,7 +118,7 @@ router.include_router(jira_add_connector_router)
|
|||
router.include_router(confluence_add_connector_router)
|
||||
router.include_router(clickup_add_connector_router)
|
||||
router.include_router(dropbox_add_connector_router)
|
||||
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
|
||||
router.include_router(model_connections_router) # Connection-centric model catalog
|
||||
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
|
||||
router.include_router(logs_router)
|
||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from app.etl_pipeline.file_classifier import (
|
|||
PLAINTEXT_EXTENSIONS,
|
||||
)
|
||||
from app.rate_limiter import limiter
|
||||
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -98,7 +99,6 @@ class AnonQuotaResponse(BaseModel):
|
|||
class AnonModelResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
provider: str
|
||||
model_name: str
|
||||
billing_tier: str = "free"
|
||||
|
|
@ -131,8 +131,7 @@ async def list_anonymous_models():
|
|||
AnonModelResponse(
|
||||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -160,8 +159,7 @@ async def get_anonymous_model(slug: str):
|
|||
return AnonModelResponse(
|
||||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -474,7 +472,15 @@ async def stream_anonymous_chat(
|
|||
except Exception as e:
|
||||
logger.exception("Anonymous chat stream error")
|
||||
await TokenQuotaService.anon_release(session_key, ip_key, request_id)
|
||||
yield streaming_service.format_error(f"Error during chat: {e!s}")
|
||||
_, error_code, _, _, user_message, extra = classify_stream_exception(
|
||||
e,
|
||||
flow_label="chat",
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
user_message,
|
||||
error_code=error_code,
|
||||
extra=extra,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
finally:
|
||||
await TokenQuotaService.anon_release_stream_slot(client_ip)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue