mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-22 21:28:12 +02:00
Compare commits
No commits in common. "main" and "v0.0.27" have entirely different histories.
602 changed files with 20889 additions and 27047 deletions
7
.github/workflows/desktop-release.yml
vendored
7
.github/workflows/desktop-release.yml
vendored
|
|
@ -95,12 +95,10 @@ jobs:
|
|||
run: pnpm build
|
||||
working-directory: surfsense_web
|
||||
env:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_URL }}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${{ vars.NEXT_PUBLIC_ZERO_CACHE_URL }}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${{ vars.NEXT_PUBLIC_DEPLOYMENT_MODE }}
|
||||
NEXT_PUBLIC_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_AUTH_TYPE }}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${{ vars.NEXT_PUBLIC_ETL_SERVICE }}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${{ vars.NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE }}
|
||||
NEXT_PUBLIC_POSTHOG_KEY: ${{ secrets.NEXT_PUBLIC_POSTHOG_KEY }}
|
||||
|
||||
- name: Install desktop dependencies
|
||||
|
|
@ -111,7 +109,6 @@ jobs:
|
|||
run: pnpm build
|
||||
working-directory: surfsense_desktop
|
||||
env:
|
||||
HOSTED_BACKEND_URL: ${{ vars.HOSTED_BACKEND_URL }}
|
||||
HOSTED_FRONTEND_URL: ${{ vars.HOSTED_FRONTEND_URL }}
|
||||
POSTHOG_KEY: ${{ secrets.POSTHOG_KEY }}
|
||||
POSTHOG_HOST: ${{ vars.POSTHOG_HOST }}
|
||||
|
|
|
|||
5
.github/workflows/docker-build.yml
vendored
5
.github/workflows/docker-build.yml
vendored
|
|
@ -199,6 +199,11 @@ jobs:
|
|||
build-args: |
|
||||
${{ matrix.image == 'backend' && format('USE_CUDA={0}', matrix.use_cuda) || '' }}
|
||||
${{ matrix.image == 'backend' && format('CUDA_EXTRA={0}', matrix.cuda_extra) || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_URL=__NEXT_PUBLIC_FASTAPI_BACKEND_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=__NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ETL_SERVICE=__NEXT_PUBLIC_ETL_SERVICE__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_ZERO_CACHE_URL=__NEXT_PUBLIC_ZERO_CACHE_URL__' || '' }}
|
||||
${{ matrix.image == 'web' && 'NEXT_PUBLIC_DEPLOYMENT_MODE=__NEXT_PUBLIC_DEPLOYMENT_MODE__' || '' }}
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
|
|
|
|||
5
.github/workflows/e2e-tests.yml
vendored
5
.github/workflows/e2e-tests.yml
vendored
|
|
@ -27,10 +27,9 @@ jobs:
|
|||
PLAYWRIGHT_TEST_EMAIL: e2e-test@surfsense.net
|
||||
PLAYWRIGHT_TEST_PASSWORD: E2eTestPassword123!
|
||||
# Frontend env: Playwright's webServer (surfsense_web/playwright.config.ts)
|
||||
# spawns `pnpm build && pnpm start` in CI.
|
||||
# spawns `pnpm build && pnpm start` in CI; these get baked into the build.
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: http://localhost:8000
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://localhost:8000
|
||||
AUTH_TYPE: LOCAL
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: LOCAL
|
||||
# Shared secret for the test-only POST /__e2e__/auth/token endpoint.
|
||||
# Must match docker-compose.e2e.yml's backend env (x-backend-env).
|
||||
E2E_MINT_SECRET: e2e-mint-secret-not-for-production
|
||||
|
|
|
|||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.0.29
|
||||
0.0.27
|
||||
|
|
|
|||
|
|
@ -30,9 +30,6 @@ SECRET_KEY=replace_me_with_a_random_string
|
|||
# Auth type: LOCAL (email/password) or GOOGLE (OAuth)
|
||||
AUTH_TYPE=LOCAL
|
||||
|
||||
# Deployment mode: self-hosted enables local filesystem connectors; cloud hides them.
|
||||
DEPLOYMENT_MODE=self-hosted
|
||||
|
||||
# Allow new user registrations (TRUE or FALSE)
|
||||
# REGISTRATION_ENABLED=TRUE
|
||||
|
||||
|
|
@ -46,47 +43,51 @@ ETL_SERVICE=DOCLING
|
|||
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# How You Access SurfSense
|
||||
# Ports (change to avoid conflicts with other services on your machine)
|
||||
# ------------------------------------------------------------------------------
|
||||
# One public URL. Browser traffic stays same-origin and Caddy routes internally.
|
||||
SURFSENSE_PUBLIC_URL=http://localhost:3929
|
||||
|
||||
# BACKEND_PORT=8929
|
||||
# FRONTEND_PORT=3929
|
||||
# ZERO_CACHE_PORT=5929
|
||||
# SEARXNG_PORT=8888
|
||||
# FLOWER_PORT=5555
|
||||
|
||||
# ==============================================================================
|
||||
# DEV COMPOSE ONLY (docker-compose.dev.yml)
|
||||
# You only need them only if you are running `docker-compose.dev.yml`.
|
||||
# ==============================================================================
|
||||
|
||||
# -- pgAdmin (database GUI) --
|
||||
# PGADMIN_PORT=5050
|
||||
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
# PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
|
||||
# -- Redis exposed port (dev only; Redis is internal-only in prod) --
|
||||
# REDIS_PORT=6379
|
||||
|
||||
# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) --
|
||||
# WHATSAPP_BRIDGE_PORT=9929
|
||||
|
||||
# -- Frontend Build Args --
|
||||
# In dev, the frontend is built from source and these are passed as build args.
|
||||
# In prod, they are automatically derived from AUTH_TYPE, ETL_SERVICE, and the port settings above.
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL
|
||||
# NEXT_PUBLIC_ETL_SERVICE=DOCLING
|
||||
# NEXT_PUBLIC_DEPLOYMENT_MODE=self-hosted
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Public Ports
|
||||
# Custom Domain / Reverse Proxy
|
||||
# ------------------------------------------------------------------------------
|
||||
# Production Docker exposes only Caddy to your machine. Caddy then routes
|
||||
# frontend, backend, and zero-cache traffic internally.
|
||||
# ONLY set these if you are serving SurfSense on a real domain via a reverse
|
||||
# proxy (e.g. Caddy, Nginx, Cloudflare Tunnel).
|
||||
# For standard localhost deployments, leave all of these commented out.
|
||||
# they are automatically derived from the port settings above.
|
||||
#
|
||||
# Local default: LISTEN_HTTP_PORT=3929
|
||||
# Domain default: LISTEN_HTTP_PORT=80 and LISTEN_HTTPS_PORT=443
|
||||
LISTEN_HTTP_PORT=3929
|
||||
LISTEN_HTTPS_PORT=443
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Custom Domain / HTTPS
|
||||
# ------------------------------------------------------------------------------
|
||||
# Leave SURFSENSE_SITE_ADDRESS as :80 for local HTTP.
|
||||
# Set it to your domain to enable automatic HTTPS:
|
||||
# SURFSENSE_SITE_ADDRESS=surf.example.com
|
||||
# CERT_EMAIL=you@example.com
|
||||
SURFSENSE_SITE_ADDRESS=:80
|
||||
CERT_EMAIL=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Advanced Reverse Proxy Settings
|
||||
# ------------------------------------------------------------------------------
|
||||
# Usually do not change these. They are for custom certificate setup, CDNs/load
|
||||
# balancers, trusted proxy IPs, or changing upload limits.
|
||||
#
|
||||
# CERT_ACME_CA=https://acme-v02.api.letsencrypt.org/directory
|
||||
# CERT_ACME_DNS=
|
||||
# If a CDN/load balancer sits in front of Caddy, narrow this to that proxy's CIDRs.
|
||||
# TRUSTED_PROXIES=0.0.0.0/0
|
||||
# SURFSENSE_MAX_BODY_SIZE=5GB
|
||||
#
|
||||
# Browser API and Zero URLs are same-origin relative behind bundled Caddy.
|
||||
# Next.js server-side calls use Docker DNS through SURFSENSE_BACKEND_INTERNAL_URL
|
||||
# set internally by docker-compose.yml. Usually do not override it.
|
||||
# NEXT_FRONTEND_URL=https://app.yourdomain.com
|
||||
# BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_FASTAPI_BACKEND_URL=https://api.yourdomain.com
|
||||
# NEXT_PUBLIC_ZERO_CACHE_URL=https://zero.yourdomain.com
|
||||
# FASTAPI_BACKEND_INTERNAL_URL=http://backend:8000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Zero-cache (real-time sync)
|
||||
|
|
@ -107,9 +108,10 @@ CERT_EMAIL=
|
|||
|
||||
# Sync worker tuning. zero-cache defaults ZERO_NUM_SYNC_WORKERS to the number
|
||||
# of CPU cores, which can exceed the connection pool limits on high-core machines.
|
||||
# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR pools.
|
||||
# Keep ZERO_UPSTREAM_MAX_CONNS and ZERO_CVR_MAX_CONNS greater than or equal to
|
||||
# ZERO_NUM_SYNC_WORKERS.
|
||||
# Each sync worker needs at least 1 connection from both the UPSTREAM and CVR
|
||||
# pools, so these constraints must hold:
|
||||
# ZERO_UPSTREAM_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS
|
||||
# ZERO_CVR_MAX_CONNS >= ZERO_NUM_SYNC_WORKERS
|
||||
# Default of 4 workers is sufficient for self-hosted / personal use.
|
||||
# ZERO_NUM_SYNC_WORKERS=4
|
||||
# ZERO_UPSTREAM_MAX_CONNS=20
|
||||
|
|
@ -123,16 +125,16 @@ CERT_EMAIL=
|
|||
|
||||
# ZERO_QUERY_URL: where zero-cache forwards query requests for resolution.
|
||||
# ZERO_MUTATE_URL: required by zero-cache when auth tokens are used, even though
|
||||
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
|
||||
# skip its own JWT verification and let the app endpoints handle auth instead.
|
||||
# The mutate endpoint is a no-op that returns an empty response.
|
||||
# SurfSense does not use Zero mutators. Setting both URLs tells zero-cache to
|
||||
# skip its own JWT verification and let the app endpoints handle auth instead.
|
||||
# The mutate endpoint is a no-op that returns an empty response.
|
||||
# Default: Docker service networking (http://frontend:3000/api/zero/...).
|
||||
# Override when running the frontend outside Docker:
|
||||
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
|
||||
# Override for custom domain only when zero-cache is not in the bundled Docker network:
|
||||
# ZERO_QUERY_URL=https://surf.example.com/api/zero/query
|
||||
# ZERO_MUTATE_URL=https://surf.example.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://host.docker.internal:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://host.docker.internal:3000/api/zero/mutate
|
||||
# Override for custom domain:
|
||||
# ZERO_QUERY_URL=https://app.yourdomain.com/api/zero/query
|
||||
# ZERO_MUTATE_URL=https://app.yourdomain.com/api/zero/mutate
|
||||
# ZERO_QUERY_URL=http://frontend:3000/api/zero/query
|
||||
# ZERO_MUTATE_URL=http://frontend:3000/api/zero/mutate
|
||||
|
||||
|
|
@ -164,26 +166,25 @@ CERT_EMAIL=
|
|||
# REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stripe (unified credit wallet, disabled by default)
|
||||
# Stripe (pay-as-you-go page packs, disabled by default)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Set TRUE to allow users to buy credit packs via Stripe Checkout. $1 buys
|
||||
# 1_000_000 micro-USD of credit; both ETL page processing and premium turns
|
||||
# debit this balance at the actual per-call provider cost from LiteLLM.
|
||||
STRIPE_CREDIT_BUYING_ENABLED=FALSE
|
||||
# Set TRUE to allow users to buy additional page packs via Stripe Checkout
|
||||
STRIPE_PAGE_BUYING_ENABLED=FALSE
|
||||
# STRIPE_SECRET_KEY=sk_test_...
|
||||
# STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
# STRIPE_CREDIT_PRICE_ID=price_...
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# STRIPE_PRICE_ID=price_...
|
||||
# STRIPE_PAGES_PER_UNIT=1000
|
||||
# STRIPE_RECONCILIATION_INTERVAL=10m
|
||||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||
|
||||
# Auto-reload: top up via a saved Stripe card when the balance drops below
|
||||
# the user-chosen threshold. Off by default.
|
||||
# AUTO_RELOAD_ENABLED=FALSE
|
||||
# AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000
|
||||
# AUTO_RELOAD_COOLDOWN_MINUTES=10
|
||||
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
|
||||
# credit; premium turns debit the actual per-call provider cost
|
||||
# reported by LiteLLM, so cheap and expensive models bill proportionally)
|
||||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
|
|
@ -220,74 +221,73 @@ STT_SERVICE=local/base
|
|||
# ------------------------------------------------------------------------------
|
||||
|
||||
# -- Google Connectors --
|
||||
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/calendar/connector/callback
|
||||
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/gmail/connector/callback
|
||||
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/google/drive/connector/callback
|
||||
# GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
|
||||
# GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
|
||||
# GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback
|
||||
|
||||
# -- Notion --
|
||||
# NOTION_CLIENT_ID=
|
||||
# NOTION_CLIENT_SECRET=
|
||||
# NOTION_REDIRECT_URI=http://localhost:3929/api/v1/auth/notion/connector/callback
|
||||
# NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback
|
||||
|
||||
# -- Slack --
|
||||
# SLACK_CLIENT_ID=
|
||||
# SLACK_CLIENT_SECRET=
|
||||
# SLACK_REDIRECT_URI=http://localhost:3929/api/v1/auth/slack/connector/callback
|
||||
# SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback
|
||||
|
||||
# -- Discord --
|
||||
# DISCORD_CLIENT_ID=
|
||||
# DISCORD_CLIENT_SECRET=
|
||||
# DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/auth/discord/connector/callback
|
||||
# DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback
|
||||
# DISCORD_BOT_TOKEN=
|
||||
|
||||
# -- Atlassian (Jira & Confluence) --
|
||||
# ATLASSIAN_CLIENT_ID=
|
||||
# ATLASSIAN_CLIENT_SECRET=
|
||||
# JIRA_REDIRECT_URI=http://localhost:3929/api/v1/auth/jira/connector/callback
|
||||
# CONFLUENCE_REDIRECT_URI=http://localhost:3929/api/v1/auth/confluence/connector/callback
|
||||
# JIRA_REDIRECT_URI=http://localhost:8000/api/v1/auth/jira/connector/callback
|
||||
# CONFLUENCE_REDIRECT_URI=http://localhost:8000/api/v1/auth/confluence/connector/callback
|
||||
|
||||
# -- Linear --
|
||||
# LINEAR_CLIENT_ID=
|
||||
# LINEAR_CLIENT_SECRET=
|
||||
# LINEAR_REDIRECT_URI=http://localhost:3929/api/v1/auth/linear/connector/callback
|
||||
# LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback
|
||||
|
||||
# -- ClickUp --
|
||||
# CLICKUP_CLIENT_ID=
|
||||
# CLICKUP_CLIENT_SECRET=
|
||||
# CLICKUP_REDIRECT_URI=http://localhost:3929/api/v1/auth/clickup/connector/callback
|
||||
# CLICKUP_REDIRECT_URI=http://localhost:8000/api/v1/auth/clickup/connector/callback
|
||||
|
||||
# -- Airtable --
|
||||
# AIRTABLE_CLIENT_ID=
|
||||
# AIRTABLE_CLIENT_SECRET=
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:3929/api/v1/auth/airtable/connector/callback
|
||||
# AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
||||
|
||||
# -- Microsoft OAuth (Teams & OneDrive) --
|
||||
# MICROSOFT_CLIENT_ID=
|
||||
# MICROSOFT_CLIENT_SECRET=
|
||||
# TEAMS_REDIRECT_URI=http://localhost:3929/api/v1/auth/teams/connector/callback
|
||||
# ONEDRIVE_REDIRECT_URI=http://localhost:3929/api/v1/auth/onedrive/connector/callback
|
||||
# TEAMS_REDIRECT_URI=http://localhost:8000/api/v1/auth/teams/connector/callback
|
||||
# ONEDRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/onedrive/connector/callback
|
||||
|
||||
# -- Dropbox --
|
||||
# DROPBOX_APP_KEY=
|
||||
# DROPBOX_APP_SECRET=
|
||||
# DROPBOX_REDIRECT_URI=http://localhost:3929/api/v1/auth/dropbox/connector/callback
|
||||
# DROPBOX_REDIRECT_URI=http://localhost:8000/api/v1/auth/dropbox/connector/callback
|
||||
|
||||
# -- Composio --
|
||||
# COMPOSIO_API_KEY=
|
||||
# COMPOSIO_ENABLED=TRUE
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:3929/api/v1/auth/composio/connector/callback
|
||||
# COMPOSIO_REDIRECT_URI=http://localhost:8000/api/v1/auth/composio/connector/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Messaging Channels (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Configure only the external chat channels you want to use.
|
||||
# GATEWAY_ENABLED=TRUE
|
||||
|
||||
# -- Telegram --
|
||||
# TELEGRAM_SHARED_BOT_TOKEN=
|
||||
# TELEGRAM_SHARED_BOT_USERNAME=
|
||||
# TELEGRAM_WEBHOOK_SECRET=
|
||||
# GATEWAY_BASE_URL=http://localhost:3929
|
||||
# GATEWAY_BASE_URL=http://localhost:8929
|
||||
# GATEWAY_TELEGRAM_INTAKE_MODE=webhook
|
||||
|
||||
# -- WhatsApp --
|
||||
|
|
@ -306,20 +306,20 @@ STT_SERVICE=local/base
|
|||
#
|
||||
# GATEWAY_SLACK_ENABLED=FALSE
|
||||
# GATEWAY_SLACK_SIGNING_SECRET=
|
||||
# GATEWAY_SLACK_REDIRECT_URI=http://localhost:3929/api/v1/gateway/slack/callback
|
||||
# GATEWAY_SLACK_REDIRECT_URI=http://localhost:8929/api/v1/gateway/slack/callback
|
||||
|
||||
# -- Discord --
|
||||
# Uses DISCORD_CLIENT_ID, DISCORD_CLIENT_SECRET, and DISCORD_BOT_TOKEN from the
|
||||
# Discord connector section.
|
||||
#
|
||||
# GATEWAY_DISCORD_ENABLED=FALSE
|
||||
# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:3929/api/v1/gateway/discord/callback
|
||||
# GATEWAY_DISCORD_REDIRECT_URI=http://localhost:8929/api/v1/gateway/discord/callback
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG (bundled web search, works out of the box with no config needed)
|
||||
# ------------------------------------------------------------------------------
|
||||
# SearXNG provides web search to all search spaces automatically.
|
||||
# To access the SearXNG UI directly in dev/deps-only compose: http://localhost:8888
|
||||
# To access the SearXNG UI directly: http://localhost:8888
|
||||
# To disable the service entirely: docker compose up --scale searxng=0
|
||||
# To point at your own SearXNG instance instead of the bundled one:
|
||||
# SEARXNG_DEFAULT_HOST=http://your-searxng:8080
|
||||
|
|
@ -407,16 +407,13 @@ SURFSENSE_ENABLE_DOOM_LOOP=true
|
|||
# ACCESS_TOKEN_LIFETIME_SECONDS=86400
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600
|
||||
|
||||
# Unified credit wallet starting balance for new users, in micro-USD
|
||||
# (default: $5). Funds both ETL page processing and premium model calls,
|
||||
# debited at the actual per-call provider cost reported by LiteLLM.
|
||||
# DEFAULT_CREDIT_MICROS_BALANCE=5000000
|
||||
# Pages limit per user for ETL (default: unlimited)
|
||||
# PAGES_LIMIT=500
|
||||
|
||||
# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL
|
||||
# effectively free for self-hosted installs. 1 page == MICROS_PER_PAGE
|
||||
# micro-USD ($0.001); premium ETL mode is 10x.
|
||||
# ETL_CREDIT_BILLING_ENABLED=FALSE
|
||||
# MICROS_PER_PAGE=1000
|
||||
# Premium credit quota per registered user, in micro-USD (default: $5).
|
||||
# Premium turns are debited at the actual per-call provider cost reported
|
||||
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
|
@ -456,36 +453,3 @@ NOLOGIN_MODE_ENABLED=FALSE
|
|||
# RESIDENTIAL_PROXY_HOSTNAME=
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
|
||||
# ==============================================================================
|
||||
# DEV / DEPS-ONLY COMPOSE OVERRIDES
|
||||
# These are only needed for docker-compose.dev.yml or docker-compose.deps-only.yml.
|
||||
# Production Docker exposes Caddy only; raw app ports below do not affect
|
||||
# docker-compose.yml.
|
||||
# ==============================================================================
|
||||
|
||||
# -- pgAdmin (database GUI, dev/deps-only only) --
|
||||
# PGADMIN_PORT=5050
|
||||
# PGADMIN_DEFAULT_EMAIL=admin@surfsense.com
|
||||
# PGADMIN_DEFAULT_PASSWORD=surfsense
|
||||
|
||||
# -- Redis exposed port (dev/deps-only only; Redis is internal-only in prod) --
|
||||
# REDIS_PORT=6379
|
||||
|
||||
# -- SearXNG exposed port (dev/deps-only only; internal-only in prod) --
|
||||
# SEARXNG_PORT=8888
|
||||
|
||||
# -- WhatsApp bridge exposed port (dev/hybrid only; prod keeps it Docker-internal) --
|
||||
# WHATSAPP_BRIDGE_PORT=9929
|
||||
|
||||
# -- Raw app ports (dev/deps-only only; prod exposes Caddy instead) --
|
||||
# BACKEND_PORT=8000
|
||||
# FRONTEND_PORT=3000
|
||||
# ZERO_CACHE_PORT=4848
|
||||
|
||||
# -- Frontend runtime flags (prod and dev compose) --
|
||||
# The frontend reads these at request time in Docker; no NEXT_PUBLIC_* rebuild
|
||||
# or startup substitution is required.
|
||||
# AUTH_TYPE=LOCAL
|
||||
# ETL_SERVICE=DOCLING
|
||||
# DEPLOYMENT_MODE=self-hosted
|
||||
|
|
|
|||
|
|
@ -106,7 +106,6 @@ services:
|
|||
volumes:
|
||||
- ../surfsense_backend/app:/app/app
|
||||
- shared_temp:/shared_tmp
|
||||
- object_store:/app/.local_object_store
|
||||
env_file:
|
||||
- ../surfsense_backend/.env
|
||||
extra_hosts:
|
||||
|
|
@ -120,7 +119,6 @@ services:
|
|||
- PYTHONPATH=/app
|
||||
- UVICORN_LOOP=asyncio
|
||||
- UNSTRUCTURED_HAS_PATCHED_LOOP=1
|
||||
- FILE_STORAGE_LOCAL_PATH=/app/.local_object_store
|
||||
- LANGCHAIN_TRACING_V2=false
|
||||
- LANGSMITH_TRACING=false
|
||||
- AUTH_TYPE=${AUTH_TYPE:-LOCAL}
|
||||
|
|
@ -173,7 +171,6 @@ services:
|
|||
volumes:
|
||||
- ../surfsense_backend/app:/app/app
|
||||
- shared_temp:/shared_tmp
|
||||
- object_store:/app/.local_object_store
|
||||
env_file:
|
||||
- ../surfsense_backend/.env
|
||||
extra_hosts:
|
||||
|
|
@ -185,7 +182,6 @@ services:
|
|||
- REDIS_APP_URL=${REDIS_URL:-redis://redis:6379/0}
|
||||
- CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
- PYTHONPATH=/app
|
||||
- FILE_STORAGE_LOCAL_PATH=/app/.local_object_store
|
||||
- SEARXNG_DEFAULT_HOST=${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
- SERVICE_ROLE=worker
|
||||
depends_on:
|
||||
|
|
@ -257,15 +253,16 @@ services:
|
|||
frontend:
|
||||
build:
|
||||
context: ../surfsense_web
|
||||
args:
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:8000}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${NEXT_PUBLIC_ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-4848}}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${NEXT_PUBLIC_DEPLOYMENT_MODE:-self-hosted}
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3000}:3000"
|
||||
env_file:
|
||||
- ../surfsense_web/.env
|
||||
environment:
|
||||
AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
|
|
@ -281,8 +278,6 @@ volumes:
|
|||
name: surfsense-dev-redis
|
||||
shared_temp:
|
||||
name: surfsense-dev-shared-temp
|
||||
object_store:
|
||||
name: surfsense-dev-object-store
|
||||
zero_cache_data:
|
||||
name: surfsense-dev-zero-cache
|
||||
whatsapp_sessions:
|
||||
|
|
|
|||
|
|
@ -1,54 +0,0 @@
|
|||
# =============================================================================
|
||||
# SurfSense — Optional Caddy reverse-proxy overlay
|
||||
# =============================================================================
|
||||
# Usage (from docker/):
|
||||
# PROXY_HTTP_PORT=8080 SURFSENSE_PUBLIC_URL=http://localhost:8080 \
|
||||
# docker compose -f docker-compose.yml -f docker-compose.proxy.yml up -d
|
||||
#
|
||||
# This overlay is for validation and custom deployments. The production
|
||||
# docker-compose.yml includes Caddy by default.
|
||||
# =============================================================================
|
||||
|
||||
services:
|
||||
backend:
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8929}:8000"
|
||||
|
||||
zero-cache:
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-5929}:4848"
|
||||
|
||||
frontend:
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3929}:3000"
|
||||
|
||||
proxy:
|
||||
image: caddy:2-alpine
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${PROXY_HTTP_PORT:-8080}:80"
|
||||
- "${PROXY_HTTPS_PORT:-8443}:443"
|
||||
volumes:
|
||||
- ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
- caddy_data:/data
|
||||
- caddy_config:/config
|
||||
environment:
|
||||
SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80}
|
||||
CERT_EMAIL: ${CERT_EMAIL:-}
|
||||
CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory}
|
||||
CERT_ACME_DNS: ${CERT_ACME_DNS:-}
|
||||
TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0}
|
||||
SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB}
|
||||
depends_on:
|
||||
frontend:
|
||||
condition: service_started
|
||||
backend:
|
||||
condition: service_healthy
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
caddy_data:
|
||||
name: surfsense-caddy-data
|
||||
caddy_config:
|
||||
name: surfsense-caddy-config
|
||||
|
|
@ -94,42 +94,12 @@ services:
|
|||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
# Single public entry point for the Docker stack. Comment this service out
|
||||
# only if you front SurfSense with your own reverse proxy.
|
||||
proxy:
|
||||
image: caddy:2-alpine
|
||||
# For DNS-01/wildcard certificates, replace image with:
|
||||
# build: ./proxy
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${LISTEN_HTTP_PORT:-3929}:80"
|
||||
- "${LISTEN_HTTPS_PORT:-443}:443"
|
||||
volumes:
|
||||
- ./proxy/Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
- caddy_data:/data
|
||||
- caddy_config:/config
|
||||
environment:
|
||||
SURFSENSE_SITE_ADDRESS: ${SURFSENSE_SITE_ADDRESS:-:80}
|
||||
CERT_EMAIL: ${CERT_EMAIL:-}
|
||||
CERT_ACME_CA: ${CERT_ACME_CA:-https://acme-v02.api.letsencrypt.org/directory}
|
||||
CERT_ACME_DNS: ${CERT_ACME_DNS:-}
|
||||
TRUSTED_PROXIES: ${TRUSTED_PROXIES:-0.0.0.0/0}
|
||||
SURFSENSE_MAX_BODY_SIZE: ${SURFSENSE_MAX_BODY_SIZE:-5GB}
|
||||
depends_on:
|
||||
frontend:
|
||||
condition: service_started
|
||||
backend:
|
||||
condition: service_healthy
|
||||
zero-cache:
|
||||
condition: service_healthy
|
||||
|
||||
backend:
|
||||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}}
|
||||
expose:
|
||||
- "8000"
|
||||
ports:
|
||||
- "${BACKEND_PORT:-8929}:8000"
|
||||
volumes:
|
||||
- shared_temp:/shared_tmp
|
||||
- object_store:/app/.local_object_store
|
||||
env_file:
|
||||
- .env
|
||||
extra_hosts:
|
||||
|
|
@ -143,9 +113,7 @@ services:
|
|||
PYTHONPATH: /app
|
||||
UVICORN_LOOP: asyncio
|
||||
UNSTRUCTURED_HAS_PATCHED_LOOP: "1"
|
||||
FILE_STORAGE_LOCAL_PATH: /app/.local_object_store
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}}
|
||||
BACKEND_URL: ${BACKEND_URL:-${SURFSENSE_PUBLIC_URL:-http://localhost:${LISTEN_HTTP_PORT:-3929}}}
|
||||
NEXT_FRONTEND_URL: ${NEXT_FRONTEND_URL:-http://localhost:${FRONTEND_PORT:-3929}}
|
||||
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
WHATSAPP_BRIDGE_URL: ${WHATSAPP_BRIDGE_URL:-http://whatsapp-bridge:9929}
|
||||
# Daytona Sandbox – uncomment and set credentials to enable cloud code execution
|
||||
|
|
@ -197,7 +165,6 @@ services:
|
|||
image: ghcr.io/modsetter/surfsense-backend:${SURFSENSE_VERSION:-latest}${SURFSENSE_VARIANT:+-${SURFSENSE_VARIANT}}
|
||||
volumes:
|
||||
- shared_temp:/shared_tmp
|
||||
- object_store:/app/.local_object_store
|
||||
env_file:
|
||||
- .env
|
||||
extra_hosts:
|
||||
|
|
@ -209,7 +176,6 @@ services:
|
|||
REDIS_APP_URL: ${REDIS_URL:-redis://redis:6379/0}
|
||||
CELERY_TASK_DEFAULT_QUEUE: surfsense
|
||||
PYTHONPATH: /app
|
||||
FILE_STORAGE_LOCAL_PATH: /app/.local_object_store
|
||||
SEARXNG_DEFAULT_HOST: ${SEARXNG_DEFAULT_HOST:-http://searxng:8080}
|
||||
SERVICE_ROLE: worker
|
||||
depends_on:
|
||||
|
|
@ -251,8 +217,8 @@ services:
|
|||
|
||||
zero-cache:
|
||||
image: rocicorp/zero:1.4.0
|
||||
expose:
|
||||
- "4848"
|
||||
ports:
|
||||
- "${ZERO_CACHE_PORT:-5929}:4848"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
environment:
|
||||
|
|
@ -286,13 +252,16 @@ services:
|
|||
|
||||
frontend:
|
||||
image: ghcr.io/modsetter/surfsense-web:${SURFSENSE_VERSION:-latest}
|
||||
expose:
|
||||
- "3000"
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-3929}:3000"
|
||||
environment:
|
||||
AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
SURFSENSE_BACKEND_INTERNAL_URL: http://backend:8000
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_URL: ${NEXT_PUBLIC_FASTAPI_BACKEND_URL:-http://localhost:${BACKEND_PORT:-8929}}
|
||||
NEXT_PUBLIC_ZERO_CACHE_URL: ${NEXT_PUBLIC_ZERO_CACHE_URL:-http://localhost:${ZERO_CACHE_PORT:-5929}}
|
||||
NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE: ${AUTH_TYPE:-LOCAL}
|
||||
NEXT_PUBLIC_ETL_SERVICE: ${ETL_SERVICE:-DOCLING}
|
||||
NEXT_PUBLIC_DEPLOYMENT_MODE: ${DEPLOYMENT_MODE:-self-hosted}
|
||||
NEXT_PUBLIC_WHATSAPP_DISPLAY_PHONE_NUMBER: ${WHATSAPP_SHARED_DISPLAY_PHONE_NUMBER:-}
|
||||
FASTAPI_BACKEND_INTERNAL_URL: ${FASTAPI_BACKEND_INTERNAL_URL:-http://backend:8000}
|
||||
labels:
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
depends_on:
|
||||
|
|
@ -309,13 +278,7 @@ volumes:
|
|||
name: surfsense-redis
|
||||
shared_temp:
|
||||
name: surfsense-shared-temp
|
||||
object_store:
|
||||
name: surfsense-object-store
|
||||
zero_cache_data:
|
||||
name: surfsense-zero-cache
|
||||
caddy_data:
|
||||
name: surfsense-caddy-data
|
||||
caddy_config:
|
||||
name: surfsense-caddy-config
|
||||
whatsapp_sessions:
|
||||
name: surfsense-whatsapp-sessions
|
||||
|
|
|
|||
|
|
@ -1,45 +0,0 @@
|
|||
{
|
||||
# Optional ACME/global settings. These are harmless in the default :80
|
||||
# localhost mode and become active when SURFSENSE_SITE_ADDRESS is a domain.
|
||||
{$CERT_EMAIL}
|
||||
acme_ca {$CERT_ACME_CA:https://acme-v02.api.letsencrypt.org/directory}
|
||||
{$CERT_ACME_DNS}
|
||||
servers {
|
||||
client_ip_headers X-Forwarded-For X-Real-IP
|
||||
trusted_proxies static {$TRUSTED_PROXIES:0.0.0.0/0}
|
||||
}
|
||||
}
|
||||
|
||||
(surfsense_proxy) {
|
||||
request_body {
|
||||
max_size {$SURFSENSE_MAX_BODY_SIZE:5GB}
|
||||
}
|
||||
|
||||
# Frontend-owned auth page (the post-login token handler). More specific than
|
||||
# /auth/*, so Caddy's matcher-specificity sort routes it here, not to backend.
|
||||
reverse_proxy /auth/callback* frontend:3000
|
||||
|
||||
# Backend auth routes (FastAPI Users + OAuth helpers).
|
||||
reverse_proxy /auth/* backend:8000
|
||||
|
||||
# Backend user profile routes (FastAPI Users users router, mounted at /users).
|
||||
reverse_proxy /users/* backend:8000
|
||||
|
||||
# Backend REST, streaming, connector OAuth, and messaging gateway endpoints.
|
||||
# FastAPI already serves /api/v1, so the path is forwarded unchanged.
|
||||
reverse_proxy /api/v1/* backend:8000 {
|
||||
flush_interval -1
|
||||
}
|
||||
|
||||
# Zero accepts a single path-component base URL (Zero >= 0.6).
|
||||
# Preserve /zero so browser cacheURL can be ${SURFSENSE_PUBLIC_URL}/zero.
|
||||
reverse_proxy /zero/* zero-cache:4848
|
||||
|
||||
# Next.js app and frontend-owned API routes:
|
||||
# /api/zero/*, /api/search, /api/contact, etc.
|
||||
reverse_proxy /* frontend:3000
|
||||
}
|
||||
|
||||
{$SURFSENSE_SITE_ADDRESS::80} {
|
||||
import surfsense_proxy
|
||||
}
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
FROM caddy:2-builder-alpine AS builder
|
||||
|
||||
RUN xcaddy build \
|
||||
--with github.com/caddy-dns/cloudflare \
|
||||
--with github.com/caddy-dns/digitalocean
|
||||
|
||||
FROM caddy:2-alpine
|
||||
|
||||
COPY --from=builder /usr/bin/caddy /usr/bin/caddy
|
||||
COPY Caddyfile /etc/caddy/Caddyfile
|
||||
|
|
@ -17,14 +17,10 @@
|
|||
# into the new PostgreSQL 17 stack. The user runs one command for both.
|
||||
# =============================================================================
|
||||
|
||||
# NOTE: Do not use [ValidateSet()] (or other validation attributes without a
|
||||
# valid default) on these params. When the script is piped into iex, PowerShell
|
||||
# applies the attributes to variables in the caller's scope, and an empty
|
||||
# $Variant fails ValidateSet with a ValidationMetadataException. Validate
|
||||
# manually below instead.
|
||||
param(
|
||||
[switch]$NoWatchtower,
|
||||
[int]$WatchtowerInterval = 86400,
|
||||
[ValidateSet("cpu", "cuda", "cuda126")]
|
||||
[string]$Variant,
|
||||
[string]$GpuCount,
|
||||
[switch]$Quiet
|
||||
|
|
@ -44,11 +40,6 @@ $MigrationMode = $false
|
|||
$SetupWatchtower = -not $NoWatchtower
|
||||
$WatchtowerContainer = "watchtower"
|
||||
|
||||
if ($Variant -and $Variant -notin @("cpu", "cuda", "cuda126")) {
|
||||
Write-Host "[SurfSense] ERROR: Invalid -Variant '$Variant'. Use 'cpu', 'cuda', or 'cuda126'." -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
if ($GpuCount -and $GpuCount -notmatch '^([0-9]+|all)$') {
|
||||
Write-Host "[SurfSense] ERROR: Invalid -GpuCount '$GpuCount'. Use a number or 'all'." -ForegroundColor Red
|
||||
exit 1
|
||||
|
|
|
|||
|
|
@ -333,13 +333,11 @@ step "Downloading SurfSense files"
|
|||
info "Installation directory: ${INSTALL_DIR}"
|
||||
mkdir -p "${INSTALL_DIR}/scripts"
|
||||
mkdir -p "${INSTALL_DIR}/searxng"
|
||||
mkdir -p "${INSTALL_DIR}/proxy"
|
||||
|
||||
FILES=(
|
||||
"docker/docker-compose.yml:docker-compose.yml"
|
||||
"docker/docker-compose.gpu.yml:docker-compose.gpu.yml"
|
||||
"docker/.env.example:.env.example"
|
||||
"docker/proxy/Caddyfile:proxy/Caddyfile"
|
||||
"docker/postgresql.conf:postgresql.conf"
|
||||
"docker/scripts/migrate-database.sh:scripts/migrate-database.sh"
|
||||
"docker/searxng/settings.yml:searxng/settings.yml"
|
||||
|
|
@ -534,12 +532,9 @@ _variant_display=$(grep '^SURFSENSE_VARIANT=' "${INSTALL_DIR}/.env" 2>/dev/null
|
|||
_variant_display="${_variant_display:-cpu}"
|
||||
step "SurfSense is now installed [${_version_display}]"
|
||||
|
||||
_public_url=$(grep '^SURFSENSE_PUBLIC_URL=' "${INSTALL_DIR}/.env" 2>/dev/null | cut -d= -f2- | tr -d '"' | head -1 || true)
|
||||
_public_url="${_public_url:-http://localhost:3929}"
|
||||
|
||||
info " SurfSense: ${_public_url}"
|
||||
info " Backend: ${_public_url}/api/v1"
|
||||
info " Zero sync: ${_public_url}/zero"
|
||||
info " Frontend: http://localhost:3929"
|
||||
info " Backend: http://localhost:8929"
|
||||
info " API Docs: http://localhost:8929/docs"
|
||||
info ""
|
||||
info " Config: ${INSTALL_DIR}/.env"
|
||||
info " Variant: ${_variant_display}"
|
||||
|
|
|
|||
|
|
@ -1,20 +1,5 @@
|
|||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
|
||||
|
||||
# --- Database startup / safety knobs (optional) ---
|
||||
# Run extension/table/index DDL on app startup. Set FALSE when schema is owned
|
||||
# exclusively by Alembic migrations.
|
||||
# DB_BOOTSTRAP_ON_STARTUP=TRUE
|
||||
# lock_timeout (ms) for boot-time DDL so a contended CREATE INDEX/TABLE fails
|
||||
# fast instead of hanging the FastAPI lifespan behind another transaction.
|
||||
# DB_DDL_LOCK_TIMEOUT_MS=5000
|
||||
# idle_in_transaction_session_timeout (ms) so an abandoned "idle in transaction"
|
||||
# session can't wedge the DB indefinitely. 0 disables. (asyncpg only)
|
||||
# DB_IDLE_IN_TX_TIMEOUT_MS=900000
|
||||
# Same, for the Celery worker engine (long ingestion/podcast/video tasks). If a
|
||||
# task hasn't touched the DB in this window it's treated as orphaned and dropped.
|
||||
# 0 disables. (asyncpg only)
|
||||
# DB_CELERY_IDLE_IN_TX_TIMEOUT_MS=3600000
|
||||
|
||||
# Deployment environment: dev or production
|
||||
SURFSENSE_ENV=dev
|
||||
|
||||
|
|
@ -30,9 +15,12 @@ CELERY_TASK_DEFAULT_QUEUE=surfsense
|
|||
# Optional: TTL in seconds for connector indexing lock key
|
||||
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
|
||||
|
||||
# Messaging Gateway: disabled by default; set TRUE to enable chat integrations.
|
||||
# Supported messaging gateways: WhatsApp, Telegram, Discord, Slack
|
||||
# GATEWAY_ENABLED=TRUE
|
||||
# Messaging Gateway (global)
|
||||
# GATEWAY_ENABLED: master switch for ALL messaging gateway channels (Telegram, WhatsApp,
|
||||
# Slack, Discord). When FALSE, no gateway background workers/supervisors start and all
|
||||
# gateway HTTP routes (webhooks, OAuth callbacks, pairing) return 404. Set per-channel
|
||||
# flags below to control individual platforms once the gateway is enabled.
|
||||
GATEWAY_ENABLED=TRUE
|
||||
|
||||
# Telegram Gateway
|
||||
# TELEGRAM_WEBHOOK_SECRET must be 1-256 chars and contain only A-Z, a-z, 0-9, _ or -
|
||||
|
|
@ -87,16 +75,23 @@ SECRET_KEY=SECRET
|
|||
|
||||
NEXT_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# Stripe Checkout for the unified credit wallet.
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Both ETL page processing and premium model
|
||||
# turns are billed against this single balance at actual provider cost.
|
||||
# Stripe Checkout for pay-as-you-go page packs
|
||||
# Configure STRIPE_PRICE_ID to point at your 1,000-page price in Stripe.
|
||||
# Pages granted per purchase = quantity * STRIPE_PAGES_PER_UNIT.
|
||||
STRIPE_SECRET_KEY=sk_test_...
|
||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
STRIPE_CREDIT_PRICE_ID=price_...
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
STRIPE_PRICE_ID=price_...
|
||||
STRIPE_PAGES_PER_UNIT=1000
|
||||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_CREDIT_BUYING_ENABLED=FALSE
|
||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||
|
||||
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||
# per-call provider cost reported by LiteLLM.
|
||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
|
|
@ -226,25 +221,15 @@ VIDEO_PRESENTATION_FPS=30
|
|||
VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
||||
|
||||
|
||||
# Unified credit wallet starting balance for new users, in micro-USD
|
||||
# (default: 5,000,000 == $5.00). The same balance funds ETL page processing
|
||||
# and premium model calls, debited at actual provider cost.
|
||||
DEFAULT_CREDIT_MICROS_BALANCE=5000000
|
||||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
|
||||
# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL
|
||||
# effectively free for self-hosted/OSS installs; hosted deployments set TRUE.
|
||||
# 1 page == MICROS_PER_PAGE micro-USD ($0.001); premium ETL mode is 10x.
|
||||
ETL_CREDIT_BILLING_ENABLED=FALSE
|
||||
MICROS_PER_PAGE=1000
|
||||
|
||||
# Low-balance warning threshold (micro-USD), surfaced to the UI. Default $0.50.
|
||||
CREDIT_LOW_BALANCE_WARNING_MICROS=500000
|
||||
|
||||
# Auto-reload: automatically top up via a saved Stripe card when the balance
|
||||
# drops below the user-chosen threshold. Off by default.
|
||||
AUTO_RELOAD_ENABLED=FALSE
|
||||
AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000
|
||||
AUTO_RELOAD_COOLDOWN_MINUTES=10
|
||||
# Premium credit quota per registered user, in micro-USD
|
||||
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||
# models bill proportionally. Applies only to models with
|
||||
# billing_tier=premium in global_llm_config.yaml.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||
# stream_new_chat estimates an upper-bound cost from the model's
|
||||
|
|
@ -323,42 +308,6 @@ FILE_STORAGE_BACKEND=local
|
|||
# AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net
|
||||
# AZURE_STORAGE_CONTAINER=surfsense-documents
|
||||
|
||||
# ETL Parse Cache
|
||||
# Reuse parser output for identical file bytes across workspaces (skips paid
|
||||
# re-parsing on LlamaCloud / Azure DI / Unstructured). Off by default.
|
||||
ETL_CACHE_ENABLED=false
|
||||
# Bump to invalidate all cached entries after a parser/behaviour change.
|
||||
# ETL_CACHE_PARSER_VERSION=1
|
||||
# Prune entries unused for this many days.
|
||||
# ETL_CACHE_TTL_DAYS=90
|
||||
# Soft cap on total cached markdown; coldest entries are evicted past it.
|
||||
# ETL_CACHE_MAX_TOTAL_MB=5120
|
||||
# Rows deleted per eviction pass.
|
||||
# ETL_CACHE_EVICTION_BATCH=500
|
||||
# Optional dedicated blob storage; unset reuses the main file storage backend.
|
||||
# ETL_CACHE_STORAGE_BACKEND=azure
|
||||
# ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache
|
||||
# ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache
|
||||
|
||||
# Embedding Cache
|
||||
# Reuse chunk+embedding output for identical markdown across workspaces (skips
|
||||
# re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend.
|
||||
# Off by default.
|
||||
EMBEDDING_CACHE_ENABLED=false
|
||||
# Bump to invalidate all cached embedding sets after a chunker change.
|
||||
# EMBEDDING_CACHE_CHUNKER_VERSION=1
|
||||
# Prune entries unused for this many days.
|
||||
# EMBEDDING_CACHE_TTL_DAYS=90
|
||||
# Soft cap on total cached embeddings; coldest entries are evicted past it.
|
||||
# EMBEDDING_CACHE_MAX_TOTAL_MB=5120
|
||||
# Rows deleted per eviction pass.
|
||||
# EMBEDDING_CACHE_EVICTION_BATCH=500
|
||||
|
||||
# Incremental re-indexing: on document edits, keep chunks whose text is
|
||||
# unchanged (reusing their embeddings) and embed only new/changed ones.
|
||||
# Set to false to fall back to delete-all + full re-embed (kill switch).
|
||||
# CHUNK_RECONCILE_ENABLED=true
|
||||
|
||||
# Daytona Sandbox (isolated code execution)
|
||||
# DAYTONA_SANDBOX_ENABLED=FALSE
|
||||
# DAYTONA_API_KEY=your-daytona-api-key
|
||||
|
|
@ -398,9 +347,7 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
|
||||
# Observability - OTel
|
||||
# Disabled by default. Uncomment to enable OpenTelemetry.
|
||||
# SURFSENSE_ENABLE_OTEL=true
|
||||
|
||||
# SURFSENSE_ENABLE_OTEL=false
|
||||
# OpenTelemetry - endpoint enables export; absent = no-op.
|
||||
# Production should point at an OTel Collector. For local docker-compose.dev.yml,
|
||||
# use http://otel-lgtm:4317 instead.
|
||||
|
|
|
|||
4
surfsense_backend/.gitignore
vendored
4
surfsense_backend/.gitignore
vendored
|
|
@ -1,12 +1,12 @@
|
|||
.env
|
||||
.venv
|
||||
venv/
|
||||
/data/
|
||||
data/
|
||||
.local_object_store/
|
||||
__pycache__/
|
||||
.flashrank_cache
|
||||
surf_new_backend.egg-info/
|
||||
/podcasts/
|
||||
podcasts/
|
||||
video_presentation_audio/
|
||||
sandbox_files/
|
||||
temp_audio/
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Revision ID: 138
|
|||
Revises: 137
|
||||
Create Date: 2026-04-30
|
||||
|
||||
Add a single thread-level column to persist the Auto model pin:
|
||||
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||
|
||||
|
|
|
|||
|
|
@ -1,235 +0,0 @@
|
|||
"""unify page limits and premium credits into a single credit_micros_balance wallet
|
||||
|
||||
Collapses the two separate economies (ETL ``pages_limit``/``pages_used`` and
|
||||
premium ``premium_credit_micros_limit``/``premium_credit_micros_used``) into one
|
||||
USD-micro wallet column ``user.credit_micros_balance`` that decreases on use and
|
||||
increases on purchase / grant. ``premium_credit_micros_reserved`` is kept (renamed
|
||||
to ``credit_micros_reserved``) for in-flight reservation holds.
|
||||
|
||||
Backfill (per existing user row):
|
||||
|
||||
balance = GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used)
|
||||
+ (CASE WHEN pages_limit < 100000000
|
||||
THEN GREATEST(0, pages_limit - pages_used) * 1000
|
||||
ELSE 0 END)
|
||||
|
||||
The ``pages_limit < 100000000`` guard skips the OSS "unlimited" default
|
||||
(``PAGES_LIMIT=999999999``) so self-hosters don't get a ~$1M credit grant.
|
||||
1 page == 1000 micros == $0.001 (matches the prior $1 / 1000 pages price).
|
||||
|
||||
Table / type renames:
|
||||
|
||||
premium_token_purchases -> credit_purchases
|
||||
premiumtokenpurchasestatus (enum)-> creditpurchasestatus
|
||||
user_incentive_tasks.pages_awarded -> credit_micros_awarded (backfilled * 1000)
|
||||
|
||||
The "user" table is in zero_publication's column list, so this migration updates
|
||||
the publication via ``apply_publication`` (canonical reconcile, per migration 155)
|
||||
BEFORE dropping the old columns it referenced.
|
||||
|
||||
IMPORTANT - before AND after running this migration (same as migration 140):
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Revision ID: 156
|
||||
Revises: 155
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from app.zero_publication import apply_publication
|
||||
|
||||
revision: str = "156"
|
||||
down_revision: str | None = "155"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = :tbl AND column_name = :col "
|
||||
"AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table, "col": column},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(conn, table: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.tables "
|
||||
"WHERE table_name = :tbl AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT pg_terminate_backend(l.pid) "
|
||||
"FROM pg_locks l "
|
||||
"JOIN pg_class c ON c.oid = l.relation "
|
||||
"WHERE c.relname = :tbl "
|
||||
" AND l.pid != pg_backend_pid()"
|
||||
),
|
||||
{"tbl": table},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Add credit_micros_balance + backfill from both legacy economies.
|
||||
# ------------------------------------------------------------------
|
||||
if not _column_exists(conn, "user", "credit_micros_balance"):
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"credit_micros_balance",
|
||||
sa.BigInteger(),
|
||||
nullable=False,
|
||||
server_default="5000000",
|
||||
),
|
||||
)
|
||||
|
||||
# Backfill only when ALL legacy source columns are present (fresh DBs
|
||||
# created from current models won't have them).
|
||||
if all(
|
||||
_column_exists(conn, "user", col)
|
||||
for col in (
|
||||
"premium_credit_micros_limit",
|
||||
"premium_credit_micros_used",
|
||||
"pages_limit",
|
||||
"pages_used",
|
||||
)
|
||||
):
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'UPDATE "user" SET credit_micros_balance = '
|
||||
"GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) "
|
||||
"+ (CASE WHEN pages_limit < 100000000 "
|
||||
" THEN GREATEST(0, pages_limit - pages_used) * 1000 "
|
||||
" ELSE 0 END)"
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Rename premium_credit_micros_reserved -> credit_micros_reserved.
|
||||
# ------------------------------------------------------------------
|
||||
if _column_exists(
|
||||
conn, "user", "premium_credit_micros_reserved"
|
||||
) and not _column_exists(conn, "user", "credit_micros_reserved"):
|
||||
op.alter_column(
|
||||
"user",
|
||||
"premium_credit_micros_reserved",
|
||||
new_column_name="credit_micros_reserved",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Reconcile the Zero publication to the new column list
|
||||
# (id, credit_micros_balance) BEFORE dropping the columns it used
|
||||
# to reference, otherwise Postgres rejects the column drops with
|
||||
# "cannot drop column ... referenced by publication".
|
||||
# ------------------------------------------------------------------
|
||||
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||
_terminate_blocked_pids(conn, "user")
|
||||
apply_publication(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Drop the legacy quota columns now that nothing references them.
|
||||
# ------------------------------------------------------------------
|
||||
for col in (
|
||||
"premium_credit_micros_limit",
|
||||
"premium_credit_micros_used",
|
||||
"pages_limit",
|
||||
"pages_used",
|
||||
):
|
||||
if _column_exists(conn, "user", col):
|
||||
op.drop_column("user", col)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Rename premium_token_purchases -> credit_purchases and its enum.
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE t.typname = 'premiumtokenpurchasestatus'
|
||||
AND n.nspname = current_schema()
|
||||
)
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE t.typname = 'creditpurchasestatus'
|
||||
AND n.nspname = current_schema()
|
||||
)
|
||||
THEN
|
||||
ALTER TYPE premiumtokenpurchasestatus RENAME TO creditpurchasestatus;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
if _table_exists(conn, "premium_token_purchases") and not _table_exists(
|
||||
conn, "credit_purchases"
|
||||
):
|
||||
op.rename_table("premium_token_purchases", "credit_purchases")
|
||||
|
||||
# ``source`` distinguishes user checkout from auto-reload top-ups.
|
||||
if _table_exists(conn, "credit_purchases") and not _column_exists(
|
||||
conn, "credit_purchases", "source"
|
||||
):
|
||||
op.add_column(
|
||||
"credit_purchases",
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.String(length=20),
|
||||
nullable=False,
|
||||
server_default="checkout",
|
||||
),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Rename user_incentive_tasks.pages_awarded -> credit_micros_awarded
|
||||
# and convert page counts to micros (1 page == 1000 micros).
|
||||
# ------------------------------------------------------------------
|
||||
if _column_exists(
|
||||
conn, "user_incentive_tasks", "pages_awarded"
|
||||
) and not _column_exists(conn, "user_incentive_tasks", "credit_micros_awarded"):
|
||||
op.alter_column(
|
||||
"user_incentive_tasks",
|
||||
"pages_awarded",
|
||||
new_column_name="credit_micros_awarded",
|
||||
type_=sa.BigInteger(),
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE user_incentive_tasks "
|
||||
"SET credit_micros_awarded = credit_micros_awarded * 1000"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op. This is a one-way data-model unification; the legacy split
|
||||
columns cannot be faithfully reconstructed from a single balance."""
|
||||
|
|
@ -1,92 +0,0 @@
|
|||
"""add auto-reload (off-session Stripe top-up) columns to user
|
||||
|
||||
Adds the saved-card + threshold plumbing that powers feature-flagged credit
|
||||
auto-reload (``AUTO_RELOAD_ENABLED``):
|
||||
|
||||
user.stripe_customer_id (text, nullable)
|
||||
user.auto_reload_enabled (bool, default false)
|
||||
user.auto_reload_threshold_micros (bigint, nullable)
|
||||
user.auto_reload_amount_micros (bigint, nullable)
|
||||
user.auto_reload_payment_method_id (text, nullable)
|
||||
user.auto_reload_failed_at (timestamptz, nullable)
|
||||
|
||||
None of these columns are part of the Zero publication (``USER_COLS`` is
|
||||
``["id", "credit_micros_balance"]``), so this migration does NOT touch the
|
||||
publication and is safe to run without the zero-cache stop/reset dance that
|
||||
migration 156 required.
|
||||
|
||||
The ``credit_purchases.source`` column (``checkout`` | ``auto_reload``) was
|
||||
already added in migration 156, so it is not repeated here.
|
||||
|
||||
Revision ID: 157
|
||||
Revises: 156
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "157"
|
||||
down_revision: str | None = "156"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = :tbl AND column_name = :col "
|
||||
"AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table, "col": column},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
_COLUMNS: list[tuple[str, sa.Column]] = [
|
||||
("stripe_customer_id", sa.Column("stripe_customer_id", sa.String(), nullable=True)),
|
||||
(
|
||||
"auto_reload_enabled",
|
||||
sa.Column(
|
||||
"auto_reload_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
),
|
||||
(
|
||||
"auto_reload_threshold_micros",
|
||||
sa.Column("auto_reload_threshold_micros", sa.BigInteger(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_amount_micros",
|
||||
sa.Column("auto_reload_amount_micros", sa.BigInteger(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_payment_method_id",
|
||||
sa.Column("auto_reload_payment_method_id", sa.String(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_failed_at",
|
||||
sa.Column("auto_reload_failed_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
for name, column in _COLUMNS:
|
||||
if not _column_exists(conn, "user", name):
|
||||
op.add_column("user", column)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
for name, _ in reversed(_COLUMNS):
|
||||
if _column_exists(conn, "user", name):
|
||||
op.drop_column("user", name)
|
||||
|
|
@ -1,215 +0,0 @@
|
|||
"""evolve podcasts: expand status lifecycle and add brief/transcript/storage columns
|
||||
|
||||
Revision ID: 158
|
||||
Revises: 157
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "158"
|
||||
down_revision: str | None = "157"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
PUBLICATION_NAME = "zero_publication"
|
||||
TARGET_STATUS_LABELS = (
|
||||
"pending",
|
||||
"awaiting_brief",
|
||||
"drafting",
|
||||
"awaiting_review",
|
||||
"rendering",
|
||||
"ready",
|
||||
"failed",
|
||||
"cancelled",
|
||||
)
|
||||
LEGACY_STATUS_LABELS = ("pending", "generating", "ready", "failed")
|
||||
|
||||
|
||||
def _drop_podcasts_from_publication() -> None:
|
||||
"""Detach podcasts from zero_publication so status can be retyped.
|
||||
|
||||
Postgres refuses ``ALTER COLUMN ... TYPE`` on a column a publication
|
||||
depends on. Some databases reach this migration with podcasts already
|
||||
published (an interim apply_publication ran during 156); drop it here and
|
||||
let migration 159 reconcile the publication to the canonical shape.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
published = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = :publication "
|
||||
"AND schemaname = current_schema() AND tablename = 'podcasts'"
|
||||
),
|
||||
{"publication": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if published:
|
||||
op.execute(f'ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts";')
|
||||
|
||||
|
||||
def _enum_labels(type_name: str) -> list[str] | None:
|
||||
rows = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT e.enumlabel "
|
||||
"FROM pg_type t "
|
||||
"JOIN pg_namespace n ON n.oid = t.typnamespace "
|
||||
"JOIN pg_enum e ON e.enumtypid = t.oid "
|
||||
"WHERE n.nspname = current_schema() AND t.typname = :type_name "
|
||||
"ORDER BY e.enumsortorder"
|
||||
),
|
||||
{"type_name": type_name},
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
return [str(row[0]) for row in rows]
|
||||
|
||||
|
||||
def _column_type_name(table: str, column: str) -> str | None:
|
||||
row = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT udt_name "
|
||||
"FROM information_schema.columns "
|
||||
"WHERE table_schema = current_schema() "
|
||||
"AND table_name = :table AND column_name = :column"
|
||||
),
|
||||
{"table": table, "column": column},
|
||||
)
|
||||
.fetchone()
|
||||
)
|
||||
return str(row[0]) if row else None
|
||||
|
||||
|
||||
def _ensure_status_enum(
|
||||
*,
|
||||
desired_labels: tuple[str, ...],
|
||||
temporary_type: str,
|
||||
create_sql: str,
|
||||
alter_sql: str,
|
||||
default_value: str,
|
||||
) -> None:
|
||||
current_labels = _enum_labels("podcast_status")
|
||||
desired = list(desired_labels)
|
||||
|
||||
if current_labels != desired:
|
||||
if current_labels is None:
|
||||
if _enum_labels(temporary_type) is None:
|
||||
raise RuntimeError("podcast_status enum is missing")
|
||||
elif _enum_labels(temporary_type) is None:
|
||||
op.execute(f"ALTER TYPE podcast_status RENAME TO {temporary_type};")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"podcast_status and its temporary replacement both exist"
|
||||
)
|
||||
|
||||
if _enum_labels("podcast_status") is None:
|
||||
op.execute(create_sql)
|
||||
|
||||
if _enum_labels("podcast_status") != desired:
|
||||
raise RuntimeError("podcast_status enum is not in the expected shape")
|
||||
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
if _column_type_name("podcasts", "status") != "podcast_status":
|
||||
op.execute(alter_sql)
|
||||
op.execute(
|
||||
f"ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT '{default_value}';"
|
||||
)
|
||||
|
||||
if _enum_labels(temporary_type) is not None:
|
||||
op.execute(f"DROP TYPE {temporary_type};")
|
||||
|
||||
|
||||
def _upgrade_status_enum() -> None:
|
||||
_ensure_status_enum(
|
||||
desired_labels=TARGET_STATUS_LABELS,
|
||||
temporary_type="podcast_status_old",
|
||||
create_sql="""
|
||||
CREATE TYPE podcast_status AS ENUM (
|
||||
'pending', 'awaiting_brief', 'drafting', 'awaiting_review',
|
||||
'rendering', 'ready', 'failed', 'cancelled'
|
||||
);
|
||||
""",
|
||||
alter_sql="""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'generating' THEN 'rendering'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
""",
|
||||
default_value="pending",
|
||||
)
|
||||
|
||||
|
||||
def _downgrade_status_enum() -> None:
|
||||
_ensure_status_enum(
|
||||
desired_labels=LEGACY_STATUS_LABELS,
|
||||
temporary_type="podcast_status_new",
|
||||
create_sql=(
|
||||
"CREATE TYPE podcast_status AS ENUM "
|
||||
"('pending', 'generating', 'ready', 'failed');"
|
||||
),
|
||||
alter_sql="""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'awaiting_brief' THEN 'pending'
|
||||
WHEN 'drafting' THEN 'generating'
|
||||
WHEN 'awaiting_review' THEN 'generating'
|
||||
WHEN 'rendering' THEN 'generating'
|
||||
WHEN 'cancelled' THEN 'failed'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
""",
|
||||
default_value="ready",
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
_upgrade_status_enum()
|
||||
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;")
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;")
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec_version "
|
||||
"INTEGER NOT NULL DEFAULT 1;"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_backend VARCHAR(32);"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_key TEXT;")
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS duration_seconds INTEGER;"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS error TEXT;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_backend;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec_version;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;")
|
||||
|
||||
# Collapse the expanded lifecycle back onto the original four values.
|
||||
_downgrade_status_enum()
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""publish podcasts to zero_publication
|
||||
|
||||
Reconciles ``zero_publication`` after migration 158 added the lifecycle columns,
|
||||
so the frontend observes podcast status and the reviewable brief by push.
|
||||
|
||||
Revision ID: 159
|
||||
Revises: 158
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
from app.zero_publication import apply_publication
|
||||
|
||||
revision: str = "159"
|
||||
down_revision: str | None = "158"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
apply_publication(op.get_bind())
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op. Historical publication shapes are immutable."""
|
||||
|
|
@ -1,299 +0,0 @@
|
|||
"""add model connections
|
||||
|
||||
Revision ID: 160
|
||||
Revises: 159
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "160"
|
||||
down_revision: str | None = "159"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
connection_scope = postgresql.ENUM(
|
||||
"GLOBAL",
|
||||
"SEARCH_SPACE",
|
||||
"USER",
|
||||
name="connectionscope",
|
||||
create_type=False,
|
||||
)
|
||||
model_source = postgresql.ENUM(
|
||||
"DISCOVERED",
|
||||
"MANUAL",
|
||||
name="modelsource",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {
|
||||
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _index_exists(table_name: str, index_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return index_name in {
|
||||
index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _create_index_if_missing(
|
||||
index_name: str,
|
||||
table_name: str,
|
||||
columns: list[str],
|
||||
) -> None:
|
||||
if not _index_exists(table_name, index_name):
|
||||
op.create_index(index_name, table_name, columns, unique=False)
|
||||
|
||||
|
||||
def _add_searchspace_column_if_missing(
|
||||
column_name: str,
|
||||
*,
|
||||
server_default: object | None = None,
|
||||
) -> None:
|
||||
if not _column_exists("searchspaces", column_name):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column(
|
||||
column_name,
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default=server_default,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
if _column_exists(table_name, column_name):
|
||||
op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
|
||||
if _index_exists(table_name, index_name):
|
||||
op.drop_index(index_name, table_name=table_name)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
connection_scope.create(bind, checkfirst=True)
|
||||
model_source.create(bind, checkfirst=True)
|
||||
|
||||
if _table_exists("connections"):
|
||||
if _column_exists("connections", "litellm_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"litellm_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif _column_exists("connections", "native_provider") and not _column_exists(
|
||||
"connections", "provider"
|
||||
):
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"native_provider",
|
||||
new_column_name="provider",
|
||||
existing_type=sa.String(length=100),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"connections",
|
||||
"provider",
|
||||
existing_type=sa.String(length=100),
|
||||
nullable=False,
|
||||
)
|
||||
elif not _column_exists("connections", "provider"):
|
||||
op.add_column(
|
||||
"connections",
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
)
|
||||
_drop_index_if_exists("connections", "ix_connections_protocol")
|
||||
_drop_column_if_exists("connections", "protocol")
|
||||
else:
|
||||
op.create_table(
|
||||
"connections",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("provider", sa.String(length=100), nullable=False),
|
||||
sa.Column("base_url", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"extra",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("scope", connection_scope, nullable=False),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
|
||||
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
|
||||
"(scope = 'USER' AND user_id IS NOT NULL)",
|
||||
name="ck_connections_scope_owner",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_native_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_native_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
if _index_exists(
|
||||
"connections", "ix_connections_litellm_provider"
|
||||
) and not _index_exists("connections", "ix_connections_provider"):
|
||||
op.execute(
|
||||
"ALTER INDEX ix_connections_litellm_provider "
|
||||
"RENAME TO ix_connections_provider"
|
||||
)
|
||||
_create_index_if_missing("ix_connections_provider", "connections", ["provider"])
|
||||
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
|
||||
|
||||
if not _table_exists("models"):
|
||||
op.create_table(
|
||||
"models",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("connection_id", sa.Integer(), nullable=False),
|
||||
sa.Column("model_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("display_name", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"source",
|
||||
model_source,
|
||||
server_default="DISCOVERED",
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("supports_chat", sa.Boolean(), nullable=True),
|
||||
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_tools", sa.Boolean(), nullable=True),
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
sa.Column(
|
||||
"capabilities_override",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
sa.Column("billing_tier", sa.String(length=50), nullable=True),
|
||||
sa.Column(
|
||||
"catalog",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connection_id"], ["connections.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not _column_exists("models", "supports_chat"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_chat", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "max_input_tokens"):
|
||||
op.add_column(
|
||||
"models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_input"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_tools"):
|
||||
op.add_column(
|
||||
"models", sa.Column("supports_tools", sa.Boolean(), nullable=True)
|
||||
)
|
||||
if not _column_exists("models", "supports_image_generation"):
|
||||
op.add_column(
|
||||
"models",
|
||||
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
|
||||
)
|
||||
_drop_column_if_exists("models", "capabilities")
|
||||
_drop_column_if_exists("models", "capabilities_declared")
|
||||
_drop_column_if_exists("models", "capabilities_verified")
|
||||
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
|
||||
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
||||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||
|
||||
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing(
|
||||
"image_gen_model_id", server_default=sa.text("0")
|
||||
)
|
||||
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
|
||||
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
|
||||
op.alter_column(
|
||||
"searchspaces",
|
||||
column_name,
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=True,
|
||||
server_default=sa.text("0"),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE searchspaces
|
||||
SET
|
||||
chat_model_id = COALESCE(chat_model_id, 0),
|
||||
image_gen_model_id = COALESCE(image_gen_model_id, 0),
|
||||
vision_model_id = COALESCE(vision_model_id, 0)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "vision_model_id")
|
||||
op.drop_column("searchspaces", "image_gen_model_id")
|
||||
op.drop_column("searchspaces", "chat_model_id")
|
||||
|
||||
op.drop_index(op.f("ix_models_billing_tier"), table_name="models")
|
||||
op.drop_index("ix_models_model_id", table_name="models")
|
||||
op.drop_index(op.f("ix_models_connection_id"), table_name="models")
|
||||
op.drop_table("models")
|
||||
|
||||
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
|
||||
op.drop_index(op.f("ix_connections_provider"), table_name="connections")
|
||||
op.drop_table("connections")
|
||||
|
||||
bind = op.get_bind()
|
||||
model_source.drop(bind, checkfirst=True)
|
||||
connection_scope.drop(bind, checkfirst=True)
|
||||
|
|
@ -1,270 +0,0 @@
|
|||
"""remove legacy model config tables
|
||||
|
||||
Revision ID: 161
|
||||
Revises: 160
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "161"
|
||||
down_revision: str | None = "160"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
litellm_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"BEDROCK",
|
||||
"VERTEX_AI",
|
||||
"GROQ",
|
||||
"COHERE",
|
||||
"MISTRAL",
|
||||
"DEEPSEEK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"REPLICATE",
|
||||
"PERPLEXITY",
|
||||
"OLLAMA",
|
||||
"ALIBABA_QWEN",
|
||||
"MOONSHOT",
|
||||
"ZHIPU",
|
||||
"ANYSCALE",
|
||||
"DEEPINFRA",
|
||||
"CEREBRAS",
|
||||
"SAMBANOVA",
|
||||
"AI21",
|
||||
"CLOUDFLARE",
|
||||
"DATABRICKS",
|
||||
"COMETAPI",
|
||||
"HUGGINGFACE",
|
||||
"GITHUB_MODELS",
|
||||
"MINIMAX",
|
||||
"CUSTOM",
|
||||
name="litellmprovider",
|
||||
create_type=False,
|
||||
)
|
||||
image_gen_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
)
|
||||
vision_provider = postgresql.ENUM(
|
||||
"OPENAI",
|
||||
"ANTHROPIC",
|
||||
"GOOGLE",
|
||||
"AZURE_OPENAI",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"XAI",
|
||||
"OPENROUTER",
|
||||
"OLLAMA",
|
||||
"GROQ",
|
||||
"TOGETHER_AI",
|
||||
"FIREWORKS_AI",
|
||||
"DEEPSEEK",
|
||||
"MISTRAL",
|
||||
"CUSTOM",
|
||||
name="visionprovider",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {
|
||||
column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
}
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
if _column_exists(table_name, column_name):
|
||||
op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def _rename_column_if_exists(
|
||||
table_name: str,
|
||||
old_column_name: str,
|
||||
new_column_name: str,
|
||||
*,
|
||||
existing_type: TypeEngine,
|
||||
existing_nullable: bool = True,
|
||||
) -> None:
|
||||
if _column_exists(table_name, old_column_name) and not _column_exists(
|
||||
table_name, new_column_name
|
||||
):
|
||||
op.alter_column(
|
||||
table_name,
|
||||
old_column_name,
|
||||
new_column_name=new_column_name,
|
||||
existing_type=existing_type,
|
||||
existing_nullable=existing_nullable,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table_name in (
|
||||
"new_llm_configs",
|
||||
"vision_llm_configs",
|
||||
"image_generation_configs",
|
||||
):
|
||||
if _table_exists(table_name):
|
||||
op.drop_table(table_name)
|
||||
|
||||
_drop_column_if_exists("searchspaces", "agent_llm_id")
|
||||
_drop_column_if_exists("searchspaces", "image_generation_config_id")
|
||||
_drop_column_if_exists("searchspaces", "vision_llm_config_id")
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_generation_config_id",
|
||||
"image_gen_model_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS litellmprovider")
|
||||
op.execute("DROP TYPE IF EXISTS imagegenprovider")
|
||||
op.execute("DROP TYPE IF EXISTS visionprovider")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
litellm_provider.create(bind, checkfirst=True)
|
||||
image_gen_provider.create(bind, checkfirst=True)
|
||||
vision_provider.create(bind, checkfirst=True)
|
||||
|
||||
_rename_column_if_exists(
|
||||
"image_generations",
|
||||
"image_gen_model_id",
|
||||
"image_generation_config_id",
|
||||
existing_type=sa.Integer(),
|
||||
)
|
||||
|
||||
if _table_exists("searchspaces"):
|
||||
if not _column_exists("searchspaces", "agent_llm_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("agent_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "image_generation_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("image_generation_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
if not _column_exists("searchspaces", "vision_llm_config_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("vision_llm_config_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
if not _table_exists("image_generation_configs"):
|
||||
op.create_table(
|
||||
"image_generation_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", image_gen_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_image_generation_configs_name"),
|
||||
"image_generation_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("vision_llm_configs"):
|
||||
op.create_table(
|
||||
"vision_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", vision_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("api_version", sa.String(length=50), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_vision_llm_configs_name"),
|
||||
"vision_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _table_exists("new_llm_configs"):
|
||||
op.create_table(
|
||||
"new_llm_configs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("name", sa.String(length=100), nullable=False),
|
||||
sa.Column("description", sa.String(length=500), nullable=True),
|
||||
sa.Column("provider", litellm_provider, nullable=False),
|
||||
sa.Column("custom_provider", sa.String(length=100), nullable=True),
|
||||
sa.Column("model_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=False),
|
||||
sa.Column("api_base", sa.String(length=500), nullable=True),
|
||||
sa.Column("litellm_params", sa.JSON(), nullable=True),
|
||||
sa.Column("system_instructions", sa.Text(), nullable=False),
|
||||
sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False),
|
||||
sa.Column("citations_enabled", sa.Boolean(), nullable=False),
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_new_llm_configs_name"),
|
||||
"new_llm_configs",
|
||||
["name"],
|
||||
unique=False,
|
||||
)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
"""add etl_cache_parses table for content-addressed parse reuse
|
||||
|
||||
Revision ID: 162
|
||||
Revises: 161
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "162"
|
||||
down_revision: str | None = "161"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS etl_cache_parses (
|
||||
id SERIAL PRIMARY KEY,
|
||||
source_sha256 VARCHAR(64) NOT NULL,
|
||||
etl_service VARCHAR(32) NOT NULL,
|
||||
mode VARCHAR(16) NOT NULL,
|
||||
parser_version INTEGER NOT NULL,
|
||||
storage_backend VARCHAR(32) NOT NULL,
|
||||
storage_key TEXT NOT NULL,
|
||||
size_bytes BIGINT NOT NULL,
|
||||
content_type VARCHAR(32) NOT NULL,
|
||||
actual_pages INTEGER NOT NULL DEFAULT 0,
|
||||
times_reused BIGINT NOT NULL DEFAULT 0,
|
||||
last_used_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT uq_etl_cache_parses_key
|
||||
UNIQUE (source_sha256, etl_service, mode, parser_version)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_last_used_at "
|
||||
"ON etl_cache_parses(last_used_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_created_at "
|
||||
"ON etl_cache_parses(created_at);"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_created_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_last_used_at;")
|
||||
op.execute("DROP TABLE IF EXISTS etl_cache_parses;")
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
"""add embedding_cache_sets table for content-addressed embedding reuse
|
||||
|
||||
Revision ID: 163
|
||||
Revises: 162
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "163"
|
||||
down_revision: str | None = "162"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS embedding_cache_sets (
|
||||
id SERIAL PRIMARY KEY,
|
||||
markdown_sha256 VARCHAR(64) NOT NULL,
|
||||
embedding_model VARCHAR(255) NOT NULL,
|
||||
embedding_dim INTEGER NOT NULL,
|
||||
chunker_kind VARCHAR(8) NOT NULL,
|
||||
chunker_version INTEGER NOT NULL,
|
||||
storage_backend VARCHAR(32) NOT NULL,
|
||||
storage_key TEXT NOT NULL,
|
||||
size_bytes BIGINT NOT NULL,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
times_reused BIGINT NOT NULL DEFAULT 0,
|
||||
last_used_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT uq_embedding_cache_sets_key
|
||||
UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_last_used_at "
|
||||
"ON embedding_cache_sets(last_used_at);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_created_at "
|
||||
"ON embedding_cache_sets(created_at);"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_created_at;")
|
||||
op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_last_used_at;")
|
||||
op.execute("DROP TABLE IF EXISTS embedding_cache_sets;")
|
||||
|
|
@ -1,219 +0,0 @@
|
|||
"""remove users that never logged back in (last_login IS NULL)
|
||||
|
||||
Migration 103 added ``user.last_login``. Any user whose ``last_login`` is still
|
||||
NULL has never authenticated since that column existed, i.e. they never logged
|
||||
back in. This migration purges those users together with everything that hangs
|
||||
off them: the search spaces they own, and (via ON DELETE CASCADE)
|
||||
``searchspaces -> documents -> chunks`` plus all other user/space-scoped rows.
|
||||
|
||||
This runs BEFORE the chunks.position backfill (revision 165) on purpose: it
|
||||
removes a large amount of dead chunk data first, so the expensive backfill has
|
||||
far fewer rows to rewrite.
|
||||
|
||||
Work is done in committed batches (not one giant cascading DELETE) so that on a
|
||||
large table it streams progress to the alembic console, keeps each transaction
|
||||
small, bounds WAL/bloat growth, and is resumable if interrupted.
|
||||
|
||||
Revision ID: 164
|
||||
Revises: 163
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "164"
|
||||
down_revision: str | None = "163"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Documents removed per committed batch. Each document delete cascades to its
|
||||
# chunks (via ix_chunks_document_id), so keep this modest to bound batch size.
|
||||
DOC_BATCH = 1_000
|
||||
# Users removed per committed batch. Each cascades to owned search spaces and
|
||||
# the remaining space-/user-scoped rows.
|
||||
USER_BATCH = 500
|
||||
# Minimum seconds between progress log lines (keeps the console readable).
|
||||
LOG_EVERY_SECONDS = 5.0
|
||||
|
||||
USER_SCRATCH = "_inactive_user_ids"
|
||||
DOC_SCRATCH = "_inactive_doc_ids"
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def _fmt_duration(seconds: float) -> str:
|
||||
seconds = int(seconds)
|
||||
h, rem = divmod(seconds, 3600)
|
||||
m, s = divmod(rem, 60)
|
||||
if h:
|
||||
return f"{h}h{m:02d}m{s:02d}s"
|
||||
if m:
|
||||
return f"{m}m{s:02d}s"
|
||||
return f"{s}s"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
|
||||
# Run the heavy work outside the migration's single transaction so each
|
||||
# batch can commit on its own.
|
||||
with op.get_context().autocommit_block():
|
||||
# Materialize the target user ids once. Rebuilt from scratch on every
|
||||
# run, so a re-run after an interruption simply picks up whoever still
|
||||
# has NULL last_login -> the migration is idempotent and resumable.
|
||||
op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};")
|
||||
op.execute(
|
||||
f"CREATE UNLOGGED TABLE {USER_SCRATCH} AS "
|
||||
'SELECT id FROM "user" WHERE last_login IS NULL;'
|
||||
)
|
||||
op.execute(f"ALTER TABLE {USER_SCRATCH} ADD PRIMARY KEY (id);")
|
||||
|
||||
total_users = (
|
||||
bind.execute(sa.text(f"SELECT count(*) FROM {USER_SCRATCH}")).scalar() or 0
|
||||
)
|
||||
if total_users == 0:
|
||||
logger.info("no users with NULL last_login; nothing to remove")
|
||||
op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"found %s users with NULL last_login (never logged back in); "
|
||||
"removing them and all data in search spaces they own",
|
||||
f"{total_users:,}",
|
||||
)
|
||||
|
||||
# Documents living in search spaces owned by those users. Deleting these
|
||||
# explicitly (in batches) is what bounds the otherwise-unbounded
|
||||
# chunks cascade.
|
||||
op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};")
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE UNLOGGED TABLE {DOC_SCRATCH} AS
|
||||
SELECT d.id
|
||||
FROM documents d
|
||||
JOIN searchspaces s ON s.id = d.search_space_id
|
||||
WHERE s.user_id IN (SELECT id FROM {USER_SCRATCH});
|
||||
"""
|
||||
)
|
||||
op.execute(f"ALTER TABLE {DOC_SCRATCH} ADD PRIMARY KEY (id);")
|
||||
total_docs = (
|
||||
bind.execute(sa.text(f"SELECT count(*) FROM {DOC_SCRATCH}")).scalar() or 0
|
||||
)
|
||||
|
||||
# Phase 1: delete documents (cascades chunks, document_versions,
|
||||
# document_files) in committed batches.
|
||||
logger.info(
|
||||
"phase 1/2: deleting %s documents (cascades their chunks) "
|
||||
"in batches of %s...",
|
||||
f"{total_docs:,}",
|
||||
f"{DOC_BATCH:,}",
|
||||
)
|
||||
_batched_delete(
|
||||
bind,
|
||||
scratch=DOC_SCRATCH,
|
||||
target_table="documents",
|
||||
target_col="id",
|
||||
batch_size=DOC_BATCH,
|
||||
total=total_docs,
|
||||
label="documents",
|
||||
)
|
||||
op.execute(f"DROP TABLE IF EXISTS {DOC_SCRATCH};")
|
||||
|
||||
# Phase 2: delete the users themselves. This cascades the now-empty
|
||||
# search spaces plus all remaining user-/space-scoped rows.
|
||||
logger.info(
|
||||
"phase 2/2: deleting %s users (cascades search spaces and "
|
||||
"remaining data) in batches of %s...",
|
||||
f"{total_users:,}",
|
||||
f"{USER_BATCH:,}",
|
||||
)
|
||||
_batched_delete(
|
||||
bind,
|
||||
scratch=USER_SCRATCH,
|
||||
target_table='"user"',
|
||||
target_col="id",
|
||||
batch_size=USER_BATCH,
|
||||
total=total_users,
|
||||
label="users",
|
||||
)
|
||||
op.execute(f"DROP TABLE IF EXISTS {USER_SCRATCH};")
|
||||
|
||||
logger.info("migration 164 finished")
|
||||
|
||||
|
||||
def _batched_delete(
|
||||
bind: sa.engine.Connection,
|
||||
*,
|
||||
scratch: str,
|
||||
target_table: str,
|
||||
target_col: str,
|
||||
batch_size: int,
|
||||
total: int,
|
||||
label: str,
|
||||
) -> None:
|
||||
"""Pop ids from ``scratch`` and delete the matching rows, one committed
|
||||
batch at a time, logging progress. Atomic per batch: the row delete and the
|
||||
scratch pop happen in a single statement, so an interrupted run leaves the
|
||||
scratch table in sync with what has actually been deleted."""
|
||||
started = time.monotonic()
|
||||
last_log = 0.0
|
||||
done = 0
|
||||
|
||||
stmt = sa.text(
|
||||
f"""
|
||||
WITH batch AS (
|
||||
SELECT id FROM {scratch} LIMIT :n
|
||||
), deleted AS (
|
||||
DELETE FROM {target_table}
|
||||
WHERE {target_col} IN (SELECT id FROM batch)
|
||||
), popped AS (
|
||||
DELETE FROM {scratch}
|
||||
WHERE id IN (SELECT id FROM batch)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT count(*) FROM popped
|
||||
"""
|
||||
)
|
||||
|
||||
while True:
|
||||
popped = bind.execute(stmt, {"n": batch_size}).scalar() or 0
|
||||
if popped == 0:
|
||||
break
|
||||
done += popped
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_log >= LOG_EVERY_SECONDS or done >= total:
|
||||
elapsed = now - started
|
||||
pct = (100.0 * done / total) if total else 100.0
|
||||
eta = (elapsed / pct * (100.0 - pct)) if pct > 0 else 0.0
|
||||
logger.info(
|
||||
"%s deleted: %.1f%% (%s/%s) elapsed %s eta %s",
|
||||
label,
|
||||
pct,
|
||||
f"{done:,}",
|
||||
f"{total:,}",
|
||||
_fmt_duration(elapsed),
|
||||
_fmt_duration(eta),
|
||||
)
|
||||
last_log = now
|
||||
|
||||
logger.info(
|
||||
"deleted %s %s in %s",
|
||||
f"{done:,}",
|
||||
label,
|
||||
_fmt_duration(time.monotonic() - started),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Irreversible: deleted users and their cascaded data cannot be restored.
|
||||
# No-op so the downgrade chain can still pass through this revision.
|
||||
logger.warning(
|
||||
"migration 164 (remove_inactive_users) is irreversible; "
|
||||
"downgrade is a no-op (deleted users/data are not restored)"
|
||||
)
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
"""add chunks.position for explicit document order
|
||||
|
||||
Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no
|
||||
longer reflect document order. The ``position`` column makes that order
|
||||
explicit and is written by the indexing pipeline for every new or re-indexed
|
||||
document.
|
||||
|
||||
This migration intentionally does NOT backfill historical rows. On a large,
|
||||
heavily-indexed table (notably a multi-hundred-GB HNSW embedding index) a bulk
|
||||
UPDATE of every chunk becomes a non-HOT update that rewrites every secondary
|
||||
index per row -- turning a one-off migration into a multi-day operation.
|
||||
Instead, existing rows keep ``position = 0`` and therefore order by the
|
||||
``Chunk.id`` tiebreaker (identical to the pre-feature behavior); new and
|
||||
re-indexed documents get correct positions from application code, and any
|
||||
document needing exact order can simply be re-indexed on demand.
|
||||
|
||||
Revision ID: 165
|
||||
Revises: 164
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "165"
|
||||
down_revision: str | None = "164"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Leftover UNLOGGED scratch table from earlier backfill attempts; dropped here
|
||||
# so re-running this migration converges the schema regardless of past state.
|
||||
SCRATCH_TABLE = "_chunk_position_backfill"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Adding a NOT NULL column with a constant default is metadata-only on
|
||||
# PostgreSQL 11+, so this is fast even on very large tables. IF NOT EXISTS
|
||||
# makes it a no-op where the column already exists.
|
||||
op.execute(
|
||||
"ALTER TABLE chunks ADD COLUMN IF NOT EXISTS position INTEGER NOT NULL DEFAULT 0;"
|
||||
)
|
||||
|
||||
# Clean up the scratch table left behind by the abandoned backfill approach.
|
||||
op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(f"DROP TABLE IF EXISTS {SCRATCH_TABLE};")
|
||||
op.execute("ALTER TABLE chunks DROP COLUMN IF EXISTS position;")
|
||||
|
|
@ -241,15 +241,8 @@ async def _create_document(
|
|||
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||
session.add_all(
|
||||
[
|
||||
Chunk(
|
||||
document_id=doc.id,
|
||||
content=text,
|
||||
embedding=embedding,
|
||||
position=i,
|
||||
)
|
||||
for i, (text, embedding) in enumerate(
|
||||
zip(chunks, chunk_embeddings, strict=True)
|
||||
)
|
||||
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||
]
|
||||
)
|
||||
return doc
|
||||
|
|
@ -296,15 +289,8 @@ async def _update_document(
|
|||
chunk_embeddings = await asyncio.to_thread(embed_texts, chunks)
|
||||
session.add_all(
|
||||
[
|
||||
Chunk(
|
||||
document_id=document.id,
|
||||
content=text,
|
||||
embedding=embedding,
|
||||
position=i,
|
||||
)
|
||||
for i, (text, embedding) in enumerate(
|
||||
zip(chunks, chunk_embeddings, strict=True)
|
||||
)
|
||||
Chunk(document_id=document.id, content=text, embedding=embedding)
|
||||
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||
]
|
||||
)
|
||||
return document
|
||||
|
|
@ -489,9 +475,7 @@ async def _load_chunks_for_snapshot(
|
|||
session: AsyncSession, *, doc_id: int
|
||||
) -> list[dict[str, str]]:
|
||||
rows = await session.execute(
|
||||
select(Chunk.content)
|
||||
.where(Chunk.document_id == doc_id)
|
||||
.order_by(Chunk.position, Chunk.id)
|
||||
select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id)
|
||||
)
|
||||
return [{"content": row.content} for row in rows.all() if row.content is not None]
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def build_agent_with_cache(
|
|||
mcp_tools_by_agent: dict[str, list[BaseTool]],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
) -> Any:
|
||||
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||
|
||||
|
|
@ -121,7 +121,7 @@ async def build_agent_with_cache(
|
|||
# Bound into the generate_image subagent tool at construction time, so it
|
||||
# must key the compiled-agent cache to avoid leaking one automation's
|
||||
# image model into another with the same config_id/search_space.
|
||||
image_gen_model_id_override,
|
||||
image_generation_config_id_override,
|
||||
)
|
||||
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mentioned_document_ids: list[int] | None = None,
|
||||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
|
||||
|
||||
``image_gen_model_id`` overrides the search space's image model for
|
||||
``image_generation_config_id`` overrides the search space's image model for
|
||||
this invocation (used by automations to run on their captured model). When
|
||||
``None``, the ``generate_image`` tool resolves the live search-space pref.
|
||||
"""
|
||||
|
|
@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
"llm": llm,
|
||||
# Per-invocation image model override (automations run on their captured
|
||||
# model). Reaches the generate_image subagent tool via subagent_dependencies.
|
||||
"image_gen_model_id_override": image_gen_model_id,
|
||||
"image_generation_config_id_override": image_generation_config_id,
|
||||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
|
|
@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
image_gen_model_id_override=image_gen_model_id,
|
||||
image_generation_config_id_override=image_generation_config_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -126,25 +126,23 @@ user: "Create issues in Linear for each of these five bugs: <list>"
|
|||
|
||||
<example>
|
||||
user: "Make a 30-second podcast of this conversation."
|
||||
→ Podcast deliverable. The `deliverables` subagent sets the podcast up and
|
||||
returns **immediately** — generation does not happen during the call. A
|
||||
live card in the chat takes over from there: the user reviews the brief
|
||||
(language, voices, length) on the card, and the episode drafts and
|
||||
renders automatically after they approve.
|
||||
→ Celery-backed deliverable. The `deliverables` subagent dispatches the
|
||||
Celery job and then **waits for it to finish** before returning. The
|
||||
call may take 10-60 seconds (or longer for video presentations) —
|
||||
that is intentional, not a hang. You always get back one of two
|
||||
Receipt shapes:
|
||||
task(deliverables, "Generate a podcast titled '<title>' from the
|
||||
following content. Aim for a 30-second style brief. Return the
|
||||
podcast id and title.\n\n<source content>")
|
||||
following content. Use a 30-second style brief. Return the podcast
|
||||
id and title.\n\n<source content>")
|
||||
Outcomes:
|
||||
- **`status="success"`**: the podcast is set up. Do NOT describe its
|
||||
current status or promise it is ready — the card tracks progress
|
||||
live and will outlive whatever you say. Just point the user at the
|
||||
card in the chat.
|
||||
- **`status="success"`**: the audio is saved. Tell the user the
|
||||
podcast is **ready** and quote the `external_id` / `preview` so
|
||||
they can find it in the podcast panel.
|
||||
- **`status="failed"`**: surface the Receipt's `error` field
|
||||
verbatim. Do NOT silently re-dispatch — the backend already tried
|
||||
and reported a real error.
|
||||
Video presentations differ: that Celery-backed call **waits for the
|
||||
render to finish** before returning (possibly minutes — intentional,
|
||||
not a hang) and ends with a terminal status. If a
|
||||
Same two-way pattern applies to video presentations (which take
|
||||
longer to render, but still return a terminal status). If a
|
||||
`task(deliverables, ...)` invocation itself times out at the subagent
|
||||
layer (separate from the Receipt), that's an operator-side problem
|
||||
with the subagent invoke timeout, not a deliverable failure — pass
|
||||
|
|
|
|||
|
|
@ -508,7 +508,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
chunk_rows = await session.execute(
|
||||
select(Chunk.id, Chunk.content)
|
||||
.where(Chunk.document_id == document.id)
|
||||
.order_by(Chunk.position, Chunk.id)
|
||||
.order_by(Chunk.id)
|
||||
)
|
||||
chunks = [
|
||||
{"chunk_id": row.id, "content": row.content} for row in chunk_rows.all()
|
||||
|
|
@ -725,7 +725,7 @@ class KBPostgresBackend(BackendProtocol):
|
|||
.join(Document, Document.id == Chunk.document_id)
|
||||
.where(Document.search_space_id == self.search_space_id)
|
||||
.where(Chunk.content.ilike(f"%{pattern}%"))
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunk_rows = await session.execute(sub)
|
||||
per_doc: dict[int, int] = {}
|
||||
|
|
|
|||
|
|
@ -394,10 +394,7 @@ async def browse_recent_documents(
|
|||
Chunk.document_id,
|
||||
Chunk.content,
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=Chunk.document_id,
|
||||
order_by=(Chunk.position, Chunk.id),
|
||||
)
|
||||
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||
.label("rn"),
|
||||
)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
|
|
@ -407,7 +404,7 @@ async def browse_recent_documents(
|
|||
chunk_query = (
|
||||
select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content)
|
||||
.where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
|
||||
.order_by(numbered.c.document_id, numbered.c.rn)
|
||||
.order_by(numbered.c.document_id, numbered.c.chunk_id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
fetched_chunks = chunk_result.all()
|
||||
|
|
@ -534,7 +531,7 @@ async def fetch_mentioned_documents(
|
|||
chunk_result = await session.execute(
|
||||
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||
.where(Chunk.document_id.in_(list(docs.keys())))
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs}
|
||||
for row in chunk_result.all():
|
||||
|
|
|
|||
|
|
@ -10,53 +10,70 @@ from langgraph.types import Command
|
|||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
Model,
|
||||
ImageGenerationConfig,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_resolver import to_litellm
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_global_model(model_id: int) -> dict | None:
|
||||
return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _get_global_connection(connection_id: int) -> dict | None:
|
||||
return next(
|
||||
(c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id),
|
||||
None,
|
||||
)
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return cfg
|
||||
return None
|
||||
|
||||
|
||||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
):
|
||||
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
|
||||
|
||||
``image_gen_model_id_override``: when set (automations running on a
|
||||
captured model), use this model id instead of reading the search space's
|
||||
live ``image_gen_model_id``.
|
||||
``image_generation_config_id_override``: when set (automations running on a
|
||||
captured model), use this config id instead of reading the search space's
|
||||
live ``image_generation_config_id``.
|
||||
"""
|
||||
del db_session # tool uses a fresh per-call session instead
|
||||
|
||||
|
|
@ -101,23 +118,26 @@ def create_generate_image_tool(
|
|||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
if image_gen_model_id_override is not None:
|
||||
if image_generation_config_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
config_id = image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
else:
|
||||
config_id = (
|
||||
search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
return _failed(
|
||||
{"error": "Search space not found"},
|
||||
error="Search space not found",
|
||||
)
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
# size/quality/style are intentionally omitted: valid values
|
||||
|
|
@ -127,86 +147,73 @@ def create_generate_image_tool(
|
|||
gen_kwargs["n"] = n
|
||||
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="image_gen",
|
||||
)
|
||||
if not candidates:
|
||||
if not ImageGenRouterService.is_initialized():
|
||||
err = (
|
||||
"No image generation models available. "
|
||||
"No image generation models configured. "
|
||||
"Please add an image model in Settings > Image Models."
|
||||
)
|
||||
return _failed({"error": err}, error=err)
|
||||
config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
response = await ImageGenRouterService.aimage_generation(
|
||||
prompt=prompt, model="auto", **gen_kwargs
|
||||
)
|
||||
|
||||
provider_base_url: str | None = None
|
||||
|
||||
if config_id < 0:
|
||||
global_model = _get_global_model(config_id)
|
||||
if not global_model or not has_capability(
|
||||
global_model, "image_gen"
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
global_connection = _get_global_connection(
|
||||
global_model["connection_id"]
|
||||
)
|
||||
if not global_connection:
|
||||
err = f"Image generation connection for model {config_id} not found"
|
||||
elif config_id < 0:
|
||||
cfg = _get_global_image_gen_config(config_id)
|
||||
if not cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
global_connection,
|
||||
global_model["model_id"],
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
provider_base_url = resolved_kwargs.get("api_base")
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = Model + Connection
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
cfg_result = await session.execute(
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
)
|
||||
db_model = cfg_result.scalars().first()
|
||||
if (
|
||||
not db_model
|
||||
or not db_model.connection
|
||||
or not db_model.connection.enabled
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
conn = db_model.connection
|
||||
if (
|
||||
conn.search_space_id is not None
|
||||
and conn.search_space_id != search_space_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if (
|
||||
conn.user_id is not None
|
||||
and conn.user_id != search_space.user_id
|
||||
):
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not has_capability(db_model, "image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
provider_base_url = resolved_kwargs.get("api_base")
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
|
|
@ -223,7 +230,7 @@ def create_generate_image_tool(
|
|||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
image_gen_model_id=config_id,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
access_token=access_token,
|
||||
|
|
@ -245,19 +252,8 @@ def create_generate_image_tool(
|
|||
|
||||
# b64_json (e.g. gpt-image-1) is served via our backend endpoint so
|
||||
# megabytes of base64 don't bloat the LLM context.
|
||||
# Some OpenAI-compatible backends (e.g. Xinference) return a relative
|
||||
# URL like /files/image.png. Browsers can't resolve these, so we
|
||||
# prepend the provider's base origin when the URL starts with "/".
|
||||
if first_image.get("url"):
|
||||
raw_url: str = first_image["url"]
|
||||
if raw_url.startswith("/") and provider_base_url:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(provider_base_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
image_url = f"{origin}{raw_url}"
|
||||
else:
|
||||
image_url = raw_url
|
||||
image_url = first_image["url"]
|
||||
elif first_image.get("b64_json"):
|
||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||
image_url = (
|
||||
|
|
|
|||
|
|
@ -51,6 +51,8 @@ def load_tools(
|
|||
create_generate_image_tool(
|
||||
search_space_id=d["search_space_id"],
|
||||
db_session=d["db_session"],
|
||||
image_gen_model_id_override=d.get("image_gen_model_id_override"),
|
||||
image_generation_config_id_override=d.get(
|
||||
"image_generation_config_id_override"
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ async def _browse_recent_documents(
|
|||
chunk_query = (
|
||||
select(Chunk)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.order_by(Chunk.document_id, Chunk.position, Chunk.id)
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
raw_chunks = chunk_result.scalars().all()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
"""Factory for a podcast-generation tool.
|
||||
|
||||
Creates the podcast and proposes its brief (language, voices, length) inline,
|
||||
then returns immediately with the row awaiting review. Everything after —
|
||||
brief approval, drafting, rendering — happens on the live podcast card, so
|
||||
this tool never blocks on generation and the chat text must not describe a
|
||||
status that the card will outgrow.
|
||||
Dispatches the heavy generation to Celery and then polls the podcast row
|
||||
until it reaches a terminal status (READY/FAILED). The tool always
|
||||
returns a real terminal ``Receipt`` — never a pending one. The wait is
|
||||
bounded by the existing per-invocation safety net
|
||||
(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode,
|
||||
HTTP / process lifetime in single-agent mode).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -17,12 +18,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
|
||||
wait_for_deliverable,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import PodcastStatus, shielded_async_session
|
||||
from app.podcasts.generation.brief import propose_brief
|
||||
from app.podcasts.service import PodcastService
|
||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -43,7 +45,7 @@ def create_generate_podcast_tool(
|
|||
user_prompt: str | None = None,
|
||||
) -> Command:
|
||||
"""
|
||||
Prepare a podcast from the provided content for the user to review.
|
||||
Generate a podcast from the provided content.
|
||||
|
||||
Use this tool when the user asks to create, generate, or make a podcast.
|
||||
Common triggers include phrases like:
|
||||
|
|
@ -53,59 +55,100 @@ def create_generate_podcast_tool(
|
|||
- "Make a podcast about..."
|
||||
- "Turn this into a podcast"
|
||||
|
||||
This sets up the podcast and proposes its brief (language, voices,
|
||||
length). The user reviews the brief on the live podcast card in the
|
||||
chat; after approval the episode drafts and renders automatically.
|
||||
Generation does not start here, and the card tracks all progress — do
|
||||
not describe the podcast's current status in your reply.
|
||||
|
||||
Args:
|
||||
source_content: The text content to convert into a podcast.
|
||||
podcast_title: Title for the podcast (default: "SurfSense Podcast")
|
||||
user_prompt: Optional steer for what the episode should focus on.
|
||||
user_prompt: Optional instructions for podcast style, tone, or format.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- status: the podcast lifecycle status (awaiting_brief on success)
|
||||
- podcast_id: the podcast ID to review in the panel
|
||||
- title: the podcast title
|
||||
- message: what the user should do next (or "error" when failed)
|
||||
- status: PodcastStatus value (pending, generating, or failed)
|
||||
- podcast_id: The podcast ID for polling (when status is pending or generating)
|
||||
- title: The podcast title
|
||||
- message: Status message (or "error" field if status is failed)
|
||||
"""
|
||||
try:
|
||||
# One DB session per tool call so parallel invocations never share an AsyncSession.
|
||||
async with shielded_async_session() as session:
|
||||
service = PodcastService(session)
|
||||
podcast = await service.create(
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
)
|
||||
podcast.source_content = source_content
|
||||
spec = await propose_brief(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
focus=user_prompt,
|
||||
)
|
||||
await service.attach_brief(podcast, spec)
|
||||
session.add(podcast)
|
||||
await session.commit()
|
||||
await session.refresh(podcast)
|
||||
podcast_id = podcast.id
|
||||
|
||||
logger.info(
|
||||
"[generate_podcast] Prepared podcast %s awaiting brief review",
|
||||
podcast_id,
|
||||
from app.tasks.celery_tasks.podcast_tasks import (
|
||||
generate_content_podcast_task,
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"status": PodcastStatus.AWAITING_BRIEF.value,
|
||||
task = generate_content_podcast_task.delay(
|
||||
podcast_id=podcast_id,
|
||||
source_content=source_content,
|
||||
search_space_id=search_space_id,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[generate_podcast] Created podcast %s, task: %s",
|
||||
podcast_id,
|
||||
task.id,
|
||||
)
|
||||
|
||||
# Wait until the Celery worker flips the row to a terminal
|
||||
# state. The wait is bounded only by the subagent invoke
|
||||
# timeout (multi-agent) or HTTP lifetime (single-agent) —
|
||||
# see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details.
|
||||
terminal_status, columns, elapsed = await wait_for_deliverable(
|
||||
model=Podcast,
|
||||
row_id=podcast_id,
|
||||
columns=[Podcast.status, Podcast.file_location],
|
||||
terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
|
||||
)
|
||||
|
||||
if terminal_status == PodcastStatus.READY:
|
||||
file_location = columns[1] if columns else None
|
||||
logger.info(
|
||||
"[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
|
||||
podcast_id,
|
||||
elapsed,
|
||||
file_location,
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"status": PodcastStatus.READY.value,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"file_location": file_location,
|
||||
"message": ("Podcast generated and saved to your podcast panel."),
|
||||
}
|
||||
return with_receipt(
|
||||
payload=payload,
|
||||
receipt=make_receipt(
|
||||
route="deliverables",
|
||||
type="podcast",
|
||||
operation="generate",
|
||||
status="success",
|
||||
external_id=str(podcast_id),
|
||||
preview=podcast_title,
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
||||
# Only other terminal state is FAILED.
|
||||
logger.warning(
|
||||
"[generate_podcast] Podcast %s FAILED in %.2fs",
|
||||
podcast_id,
|
||||
elapsed,
|
||||
)
|
||||
err = "Background worker reported FAILED status for this podcast."
|
||||
payload = {
|
||||
"status": PodcastStatus.FAILED.value,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"message": (
|
||||
"Podcast set up. The card in the chat handles the rest: "
|
||||
"the user reviews the brief (language, voices, length) "
|
||||
"there, and the episode drafts and renders automatically "
|
||||
"after approval. The card tracks progress live, so do not "
|
||||
"state the podcast's current status in your reply."
|
||||
),
|
||||
"error": err,
|
||||
}
|
||||
return with_receipt(
|
||||
payload=payload,
|
||||
|
|
@ -113,9 +156,10 @@ def create_generate_podcast_tool(
|
|||
route="deliverables",
|
||||
type="podcast",
|
||||
operation="generate",
|
||||
status="success",
|
||||
status="failed",
|
||||
external_id=str(podcast_id),
|
||||
preview=podcast_title,
|
||||
error=err,
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database model-connections table (user-created configs with positive IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
|
|
@ -24,6 +24,8 @@ from langchain_core.messages import AIMessage, BaseMessage
|
|||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
|
|
@ -31,7 +33,10 @@ from app.agents.chat.runtime.prompt_caching import (
|
|||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -46,19 +51,16 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
|||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
sanitized: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
next_msg = msg.model_copy(deep=True)
|
||||
if isinstance(next_msg.content, list):
|
||||
next_msg.content = _sanitize_content(next_msg.content)
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
if (
|
||||
isinstance(next_msg, AIMessage)
|
||||
and (not next_msg.content or next_msg.content == "")
|
||||
and getattr(next_msg, "tool_calls", None)
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
):
|
||||
next_msg.content = None # type: ignore[assignment]
|
||||
sanitized.append(next_msg)
|
||||
return sanitized
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
|
|
@ -89,21 +91,13 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
):
|
||||
yield chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return await super()._agenerate(
|
||||
_sanitize_messages(messages),
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
|
||||
# in provider_capabilities so the YAML loader can resolve prefixes during
|
||||
# app.config init without importing the agent/tools tree.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
|
|
@ -127,9 +121,8 @@ class AgentConfig:
|
|||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines resolved model settings with prompt configuration.
|
||||
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
|
||||
a concrete global or BYOK model before constructing ChatLiteLLM.
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
|
|
@ -177,7 +170,7 @@ class AgentConfig:
|
|||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto",
|
||||
config_name="Auto (Fastest)",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
|
|
@ -188,21 +181,64 @@ class AgentConfig:
|
|||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a NewLLMConfig database model."""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""Build an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
Supports prompt fields such as system_instructions,
|
||||
use_default_system_instructions, and citations_enabled.
|
||||
Supports the same prompt fields as NewLLMConfig (system_instructions,
|
||||
use_default_system_instructions, citations_enabled).
|
||||
"""
|
||||
# Lazy import: keeps provider_capabilities (and litellm) out of init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider") or yaml_config.get(
|
||||
"litellm_provider", ""
|
||||
)
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -288,15 +324,93 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
|||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load a NewLLMConfig from the database by ID."""
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""Load the agent LLM config for a search space via its agent_llm_id.
|
||||
|
||||
Positive id -> DB; negative -> YAML; None -> first global config (-1).
|
||||
"""
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# In-memory covers static YAML + dynamic OpenRouter configs.
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""Create a ChatLiteLLM instance from a global LLM config dictionary."""
|
||||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider") or llm_config.get(
|
||||
"litellm_provider", "openai"
|
||||
)
|
||||
model_string = f"{provider}/{llm_config['model_name']}"
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
@ -319,17 +433,29 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Create a ChatLiteLLM from an already resolved concrete model config."""
|
||||
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
|
||||
if agent_config.is_auto_mode:
|
||||
print(
|
||||
"Error: Auto mode must be resolved to a concrete model before LLM creation"
|
||||
)
|
||||
return None
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal injection points only: auto-mode fans out across
|
||||
# providers, so provider-specific kwargs have no known target.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
model_string = f"{agent_config.provider}/{agent_config.model_name}"
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
|
|||
8
surfsense_backend/app/agents/podcaster/__init__.py
Normal file
8
surfsense_backend/app/agents/podcaster/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""New LangGraph Agent.
|
||||
|
||||
This module defines a custom graph.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
__all__ = ["graph"]
|
||||
29
surfsense_backend/app/agents/podcaster/configuration.py
Normal file
29
surfsense_backend/app/agents/podcaster/configuration.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Define the configurable parameters for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the agent."""
|
||||
|
||||
# Changeme: Add configurable values here!
|
||||
# these values can be pre-set when you
|
||||
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
|
||||
# and when you invoke the graph
|
||||
podcast_title: str
|
||||
search_space_id: int
|
||||
user_prompt: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
29
surfsense_backend/app/agents/podcaster/graph.py
Normal file
29
surfsense_backend/app/agents/podcaster/graph.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
||||
from .state import State
|
||||
|
||||
|
||||
def build_graph():
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
# Add the node to the graph
|
||||
workflow.add_node("create_podcast_transcript", create_podcast_transcript)
|
||||
workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio)
|
||||
|
||||
# Set the entrypoint as `call_model`
|
||||
workflow.add_edge("__start__", "create_podcast_transcript")
|
||||
workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio")
|
||||
workflow.add_edge("create_merged_podcast_audio", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
195
surfsense_backend/app/agents/podcaster/nodes.py
Normal file
195
surfsense_backend/app/agents/podcaster/nodes.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from litellm import aspeech
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
|
||||
from .utils import get_voice_for_provider
|
||||
|
||||
|
||||
async def create_podcast_transcript(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate the podcast transcript from the source content."""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
user_prompt = configuration.user_prompt
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No agent LLM configured for search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
prompt = get_podcast_generation_prompt(user_prompt)
|
||||
messages = [
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(
|
||||
content=f"<source_content>{state.source_content}</source_content>"
|
||||
),
|
||||
]
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
# Reasoning models may return content as blocks; normalise to a string.
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
try:
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
parsed_data = json.loads(json_str)
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||
starting_transcript = PodcastTranscriptEntry(
|
||||
speaker_id=1, dialog="Welcome to Surfsense Podcast."
|
||||
)
|
||||
|
||||
transcript = state.podcast_transcript
|
||||
|
||||
# transcript may be a PodcastTranscripts object or already a list.
|
||||
if hasattr(transcript, "podcast_transcripts"):
|
||||
transcript_entries = transcript.podcast_transcripts
|
||||
else:
|
||||
transcript_entries = transcript
|
||||
|
||||
merged_transcript = [starting_transcript, *transcript_entries]
|
||||
|
||||
temp_dir = Path("temp_audio")
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
output_path = f"podcasts/{session_id}_podcast.mp3"
|
||||
os.makedirs("podcasts", exist_ok=True)
|
||||
|
||||
audio_files = []
|
||||
|
||||
async def generate_speech_for_segment(segment, index):
|
||||
if hasattr(segment, "speaker_id"):
|
||||
speaker_id = segment.speaker_id
|
||||
dialog = segment.dialog
|
||||
else:
|
||||
speaker_id = segment.get("speaker_id", 0)
|
||||
dialog = segment.get("dialog", "")
|
||||
|
||||
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
||||
|
||||
if app_config.TTS_SERVICE == "local/kokoro":
|
||||
filename = f"{temp_dir}/{session_id}_{index}.wav"
|
||||
else:
|
||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||
|
||||
try:
|
||||
if app_config.TTS_SERVICE == "local/kokoro":
|
||||
kokoro_service = await get_kokoro_tts_service(
|
||||
lang_code="a"
|
||||
) # American English
|
||||
audio_path = await kokoro_service.generate_speech(
|
||||
text=dialog, voice=voice, speed=1.0, output_path=filename
|
||||
)
|
||||
return audio_path
|
||||
else:
|
||||
if app_config.TTS_SERVICE_API_BASE:
|
||||
response = await aspeech(
|
||||
model=app_config.TTS_SERVICE,
|
||||
api_base=app_config.TTS_SERVICE_API_BASE,
|
||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||
voice=voice,
|
||||
input=dialog,
|
||||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
else:
|
||||
response = await aspeech(
|
||||
model=app_config.TTS_SERVICE,
|
||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||
voice=voice,
|
||||
input=dialog,
|
||||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return filename
|
||||
except Exception as e:
|
||||
print(f"Error generating speech for segment {index}: {e!s}")
|
||||
raise
|
||||
|
||||
tasks = [
|
||||
generate_speech_for_segment(segment, i)
|
||||
for i, segment in enumerate(merged_transcript)
|
||||
]
|
||||
audio_files = await asyncio.gather(*tasks)
|
||||
|
||||
try:
|
||||
ffmpeg = FFmpeg().option("y")
|
||||
for audio_file in audio_files:
|
||||
ffmpeg = ffmpeg.input(audio_file)
|
||||
|
||||
filter_complex = []
|
||||
for i in range(len(audio_files)):
|
||||
filter_complex.append(f"[{i}:0]")
|
||||
|
||||
filter_complex_str = (
|
||||
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
)
|
||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||
await ffmpeg.execute()
|
||||
|
||||
print(f"Successfully created podcast audio: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error merging audio files: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
for audio_file in audio_files:
|
||||
try:
|
||||
os.remove(audio_file)
|
||||
except Exception as e:
|
||||
print(f"Error removing audio file {audio_file}: {e!s}")
|
||||
pass
|
||||
|
||||
return {
|
||||
"podcast_transcript": merged_transcript,
|
||||
"final_podcast_file_path": output_path,
|
||||
}
|
||||
122
surfsense_backend/app/agents/podcaster/prompts.py
Normal file
122
surfsense_backend/app/agents/podcaster/prompts.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import datetime
|
||||
|
||||
|
||||
def get_podcast_generation_prompt(user_prompt: str | None = None):
|
||||
return f"""
|
||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
<podcast_generation_system>
|
||||
You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery.
|
||||
|
||||
{
|
||||
f'''
|
||||
You **MUST** strictly adhere to the following user instruction while generating the podcast script:
|
||||
<user_instruction>
|
||||
{user_prompt}
|
||||
</user_instruction>
|
||||
'''
|
||||
if user_prompt
|
||||
else ""
|
||||
}
|
||||
|
||||
<input>
|
||||
- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue.
|
||||
</input>
|
||||
|
||||
<output_format>
|
||||
A JSON object containing the podcast transcript with alternating speakers:
|
||||
{{
|
||||
"podcast_transcripts": [
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Speaker 0 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Speaker 1 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Speaker 0 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Speaker 1 dialog here"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</output_format>
|
||||
|
||||
<guidelines>
|
||||
1. **Establish Distinct & Consistent Host Personas:**
|
||||
* **Speaker 0 (Lead Host):** Drives the conversation forward, introduces segments, poses key questions derived from the source content, and often summarizes takeaways. Maintain a guiding, clear, and engaging tone.
|
||||
* **Speaker 1 (Co-Host/Expert):** Offers deeper insights, provides alternative viewpoints or elaborations on the source content, asks clarifying or challenging questions, and shares relevant anecdotes or examples. Adopt a complementary tone (e.g., analytical, enthusiastic, reflective, slightly skeptical).
|
||||
* **Consistency is Key:** Ensure each speaker maintains their distinct voice, vocabulary choice, sentence structure, and perspective throughout the entire script. Avoid having them sound interchangeable. Their interaction should feel like a genuine partnership.
|
||||
|
||||
2. **Craft Natural & Dynamic Dialogue:**
|
||||
* **Emulate Real Conversation:** Use contractions (e.g., "don't", "it's"), interjections ("Oh!", "Wow!", "Hmm"), discourse markers ("you know", "right?", "well"), and occasional natural pauses or filler words. Avoid overly formal language or complex sentence structures typical of written text.
|
||||
* **Foster Interaction & Chemistry:** Write dialogue where speakers genuinely react *to each other*. They should build on points ("Exactly, and that reminds me..."), ask follow-up questions ("Could you expand on that?"), express agreement/disagreement respectfully ("That's a fair point, but have you considered...?"), and show active listening.
|
||||
* **Vary Rhythm & Pace:** Mix short, punchy lines with longer, more explanatory ones. Vary sentence beginnings. Use questions to break up exposition. The rhythm should feel spontaneous, not monotonous.
|
||||
* **Inject Personality & Relatability:** Allow for appropriate humor, moments of surprise or curiosity, brief personal reflections ("I actually experienced something similar..."), or relatable asides that fit the hosts' personas and the topic. Lightly reference past discussions if it enhances context ("Remember last week when we touched on...?").
|
||||
|
||||
3. **Structure for Flow and Listener Engagement:**
|
||||
* **Natural Beginning:** Start with dialogue that flows naturally after an introduction (which will be added manually). Avoid redundant greetings or podcast name mentions since these will be added separately.
|
||||
* **Logical Progression & Signposting:** Guide the listener through the information smoothly. Use clear transitions to link different ideas or segments ("So, now that we've covered X, let's dive into Y...", "That actually brings me to another key finding..."). Ensure topics flow logically from one to the next.
|
||||
* **Meaningful Conclusion:** Summarize the key takeaways or main points discussed, reinforcing the core message derived from the source content. End with a final thought, a lingering question for the audience, or a brief teaser for what's next, providing a sense of closure. Avoid abrupt endings.
|
||||
|
||||
4. **Integrate Source Content Seamlessly & Accurately:**
|
||||
* **Translate, Don't Recite:** Rephrase information from the `<source_content>` into conversational language suitable for each host's persona. Avoid directly copying dense sentences or technical jargon without explanation. The goal is discussion, not narration.
|
||||
* **Explain & Contextualize:** Use analogies, simple examples, storytelling, or have one host ask clarifying questions (acting as a listener surrogate) to break down complex ideas from the source.
|
||||
* **Weave Information Naturally:** Integrate facts, data, or key points from the source *within* the dialogue, not as standalone, undigested blocks. Attribute information conversationally where appropriate ("The research mentioned...", "Apparently, the key factor is...").
|
||||
* **Balance Depth & Accessibility:** Ensure the conversation is informative and factually accurate based on the source content, but prioritize clear communication and engaging delivery over exhaustive technical detail. Make it understandable and interesting for a general audience.
|
||||
|
||||
5. **Length & Pacing:**
|
||||
* **Six-Minute Duration:** Create a transcript that, when read at a natural speaking pace, would result in approximately 6 minutes of audio. Typically, this means around 1000 words total (based on average speaking rate of 150 words per minute).
|
||||
* **Concise Speaking Turns:** Keep most speaking turns relatively brief and focused. Aim for a natural back-and-forth rhythm rather than extended monologues.
|
||||
* **Essential Content Only:** Prioritize the most important information from the source content. Focus on quality over quantity, ensuring every line contributes meaningfully to the topic.
|
||||
</guidelines>
|
||||
|
||||
<examples>
|
||||
Input: "Quantum computing uses quantum bits or qubits which can exist in multiple states simultaneously due to superposition."
|
||||
|
||||
Output:
|
||||
{{
|
||||
"podcast_transcripts": [
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Today we're diving into the mind-bending world of quantum computing. You know, this is a topic I've been excited to cover for weeks."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Same here! And I know our listeners have been asking for it. But I have to admit, the concept of quantum computing makes my head spin a little. Can we start with the basics?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Absolutely. So regular computers use bits, right? Little on-off switches that are either 1 or 0. But quantum computers use something called qubits, and this is where it gets fascinating."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Wait, what makes qubits so special compared to regular bits?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "The magic is in something called superposition. These qubits can exist in multiple states at the same time, not just 1 or 0."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "That sounds impossible! How would you even picture that?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Think of it like a coin spinning in the air. Before it lands, is it heads or tails?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Well, it's... neither? Or I guess both, until it lands? Oh, I think I see where you're going with this."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</examples>
|
||||
|
||||
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
|
||||
</podcast_generation_system>
|
||||
"""
|
||||
43
surfsense_backend/app/agents/podcaster/state.py
Normal file
43
surfsense_backend/app/agents/podcaster/state.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class PodcastTranscriptEntry(BaseModel):
|
||||
"""
|
||||
Represents a single entry in a podcast transcript.
|
||||
"""
|
||||
|
||||
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
||||
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
||||
|
||||
|
||||
class PodcastTranscripts(BaseModel):
|
||||
"""
|
||||
Represents the full podcast transcript structure.
|
||||
"""
|
||||
|
||||
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
|
||||
..., description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the input state for the agent, representing a narrower interface to the outside world.
|
||||
|
||||
This class is used to define the initial state and structure of incoming data.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
podcast_transcript: list[PodcastTranscriptEntry] | None = None
|
||||
final_podcast_file_path: str | None = None
|
||||
84
surfsense_backend/app/agents/podcaster/utils.py
Normal file
84
surfsense_backend/app/agents/podcaster/utils.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
||||
"""
|
||||
Get the appropriate voice configuration based on the TTS provider and speaker ID.
|
||||
|
||||
Args:
|
||||
provider: The TTS provider (e.g., "openai/tts-1", "vertex_ai/test")
|
||||
speaker_id: The ID of the speaker (0-5)
|
||||
|
||||
Returns:
|
||||
Voice configuration - string for OpenAI, dict for Vertex AI
|
||||
"""
|
||||
if provider == "local/kokoro":
|
||||
# Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
|
||||
kokoro_voices = {
|
||||
0: "am_adam", # Default/intro voice
|
||||
1: "af_bella", # First speaker
|
||||
}
|
||||
return kokoro_voices.get(speaker_id, "af_heart")
|
||||
|
||||
# Extract provider type from the model string
|
||||
provider_type = (
|
||||
provider.split("/")[0].lower() if "/" in provider else provider.lower()
|
||||
)
|
||||
|
||||
if provider_type == "openai":
|
||||
# OpenAI voice mapping - simple string values
|
||||
openai_voices = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
2: "fable", # Second speaker
|
||||
3: "onyx", # Third speaker
|
||||
4: "nova", # Fourth speaker
|
||||
5: "shimmer", # Fifth speaker
|
||||
}
|
||||
return openai_voices.get(speaker_id, "alloy")
|
||||
|
||||
elif provider_type == "vertex_ai":
|
||||
# Vertex AI voice mapping - dict with languageCode and name
|
||||
vertex_voices = {
|
||||
0: {
|
||||
"languageCode": "en-US",
|
||||
"name": "en-US-Studio-O",
|
||||
},
|
||||
1: {
|
||||
"languageCode": "en-US",
|
||||
"name": "en-US-Studio-M",
|
||||
},
|
||||
2: {
|
||||
"languageCode": "en-UK",
|
||||
"name": "en-UK-Studio-A",
|
||||
},
|
||||
3: {
|
||||
"languageCode": "en-UK",
|
||||
"name": "en-UK-Studio-B",
|
||||
},
|
||||
4: {
|
||||
"languageCode": "en-AU",
|
||||
"name": "en-AU-Studio-A",
|
||||
},
|
||||
5: {
|
||||
"languageCode": "en-AU",
|
||||
"name": "en-AU-Studio-B",
|
||||
},
|
||||
}
|
||||
return vertex_voices.get(speaker_id, vertex_voices[0])
|
||||
elif provider_type == "azure":
|
||||
# OpenAI voice mapping - simple string values
|
||||
azure_voices = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
2: "fable", # Second speaker
|
||||
3: "onyx", # Third speaker
|
||||
4: "nova", # Fourth speaker
|
||||
5: "shimmer", # Fifth speaker
|
||||
}
|
||||
return azure_voices.get(speaker_id, "alloy")
|
||||
|
||||
else:
|
||||
# Default fallback to OpenAI format for unknown providers
|
||||
default_voices = {
|
||||
0: {},
|
||||
1: {},
|
||||
}
|
||||
return default_voices.get(speaker_id, default_voices[0])
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
"""Video Presentation LangGraph Agent.
|
||||
|
||||
This module defines a graph for generating slide-based video presentations
|
||||
from source content, with TTS narration per slide.
|
||||
This module defines a graph for generating video presentations
|
||||
from source content, similar to the podcaster agent but producing
|
||||
slide-based video presentations with TTS narration.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from app.config import (
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
|
|
@ -621,6 +622,7 @@ async def lifespan(app: FastAPI):
|
|||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
|
||||
# worker readiness. ``shield`` so Uvicorn cancelling startup
|
||||
|
|
|
|||
|
|
@ -39,31 +39,31 @@ async def build_dependencies(
|
|||
*,
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
chat_model_id: int | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
vision_model_id: int | None = None,
|
||||
agent_llm_id: int | None = None,
|
||||
image_generation_config_id: int | None = None,
|
||||
vision_llm_config_id: int | None = None,
|
||||
) -> AgentDependencies:
|
||||
"""Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer.
|
||||
|
||||
Resolves the chat model from the automation's *captured* model snapshot
|
||||
(``chat_model_id``) so runs are insulated from later chat/search-space model
|
||||
Resolves the agent LLM from the automation's *captured* model snapshot
|
||||
(``agent_llm_id``) so runs are insulated from later chat/search-space model
|
||||
changes. The model policy is enforced here as a runtime backstop: a captured
|
||||
model that is no longer billable (e.g. a premium global config was removed)
|
||||
fails the run clearly instead of silently consuming a free model.
|
||||
|
||||
When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``chat_model_id`` and validate that.
|
||||
When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback),
|
||||
fall back to the live search space's ``agent_llm_id`` and validate that.
|
||||
"""
|
||||
if chat_model_id is not None:
|
||||
if agent_llm_id is not None:
|
||||
try:
|
||||
assert_models_billable(
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_chat_model_id = chat_model_id or 0
|
||||
resolved_agent_llm_id = agent_llm_id or 0
|
||||
else:
|
||||
search_space = await session.get(SearchSpace, search_space_id)
|
||||
if search_space is None:
|
||||
|
|
@ -72,15 +72,15 @@ async def build_dependencies(
|
|||
assert_automation_models_billable(search_space)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise DependencyError(str(exc)) from exc
|
||||
resolved_chat_model_id = search_space.chat_model_id or 0
|
||||
resolved_agent_llm_id = search_space.agent_llm_id or 0
|
||||
|
||||
llm, agent_config, err = await load_llm_bundle(
|
||||
session,
|
||||
config_id=resolved_chat_model_id,
|
||||
config_id=resolved_agent_llm_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if err is not None or llm is None:
|
||||
raise DependencyError(err or "failed to load chat model config")
|
||||
raise DependencyError(err or "failed to load agent LLM config")
|
||||
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
|
|
|
|||
|
|
@ -150,9 +150,9 @@ async def run_agent_task(
|
|||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
chat_model_id=ctx.chat_model_id,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
vision_model_id=ctx.vision_model_id,
|
||||
agent_llm_id=ctx.agent_llm_id,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
vision_llm_config_id=ctx.vision_llm_config_id,
|
||||
)
|
||||
|
||||
agent = await create_multi_agent_chat_deep_agent(
|
||||
|
|
@ -167,7 +167,7 @@ async def run_agent_task(
|
|||
firecrawl_api_key=deps.firecrawl_api_key,
|
||||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
image_generation_config_id=ctx.image_generation_config_id,
|
||||
)
|
||||
|
||||
agent_query, runtime_context = await _resolve_mention_context(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ class ActionContext:
|
|||
# Captured model snapshot from the automation definition (``definition.models``),
|
||||
# resolved per run instead of the live search space. ``None`` falls back to the
|
||||
# search space's current prefs (defensive; should not happen post-capture).
|
||||
chat_model_id: int | None = None
|
||||
image_gen_model_id: int | None = None
|
||||
vision_model_id: int | None = None
|
||||
agent_llm_id: int | None = None
|
||||
image_generation_config_id: int | None = None
|
||||
vision_llm_config_id: int | None = None
|
||||
|
||||
|
||||
ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
||||
|
|
|
|||
|
|
@ -132,7 +132,9 @@ def _build_action_ctx(
|
|||
step_id=step.step_id,
|
||||
search_space_id=automation.search_space_id,
|
||||
creator_user_id=automation.created_by_user_id,
|
||||
chat_model_id=models.chat_model_id if models else None,
|
||||
image_gen_model_id=models.image_gen_model_id if models else None,
|
||||
vision_model_id=models.vision_model_id if models else None,
|
||||
agent_llm_id=models.agent_llm_id if models else None,
|
||||
image_generation_config_id=(
|
||||
models.image_generation_config_id if models else None
|
||||
),
|
||||
vision_llm_config_id=models.vision_llm_config_id if models else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec
|
|||
class AutomationModels(BaseModel):
|
||||
"""Captured model profile for an automation.
|
||||
|
||||
Snapshotted from the search space's model roles at create time so runs are
|
||||
insulated from later chat/search-space model changes. Model-id conventions
|
||||
Snapshotted from the search space's preferences at create time so runs are
|
||||
insulated from later chat/search-space model changes. Config-id conventions
|
||||
match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chat_model_id: int = 0
|
||||
image_gen_model_id: int = 0
|
||||
vision_model_id: int = 0
|
||||
agent_llm_id: int = 0
|
||||
image_generation_config_id: int = 0
|
||||
vision_llm_config_id: int = 0
|
||||
|
||||
|
||||
class AutomationDefinition(BaseModel):
|
||||
|
|
|
|||
|
|
@ -57,9 +57,9 @@ class AutomationService:
|
|||
else:
|
||||
search_space = await self._assert_models_billable(payload.search_space_id)
|
||||
payload.definition.models = AutomationModels(
|
||||
chat_model_id=search_space.chat_model_id or 0,
|
||||
image_gen_model_id=search_space.image_gen_model_id or 0,
|
||||
vision_model_id=search_space.vision_model_id or 0,
|
||||
agent_llm_id=search_space.agent_llm_id or 0,
|
||||
image_generation_config_id=search_space.image_generation_config_id or 0,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id or 0,
|
||||
)
|
||||
|
||||
automation = Automation(
|
||||
|
|
@ -225,9 +225,9 @@ class AutomationService:
|
|||
"""
|
||||
try:
|
||||
assert_models_billable(
|
||||
chat_model_id=models.chat_model_id,
|
||||
image_gen_model_id=models.image_gen_model_id,
|
||||
vision_model_id=models.vision_model_id,
|
||||
agent_llm_id=models.agent_llm_id,
|
||||
image_generation_config_id=models.image_generation_config_id,
|
||||
vision_llm_config_id=models.vision_llm_config_id,
|
||||
)
|
||||
except AutomationModelPolicyError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
Automations run unattended, so every run must be **billable**: it may only use
|
||||
either a premium global model (``billing_tier == "premium"``) or a user-provided
|
||||
BYOK model (a positive model id pointing at a per-user/per-space DB row). Free
|
||||
BYOK model (a positive config id pointing at a per-user/per-space DB row). Free
|
||||
global models and Auto mode are blocked, because Auto can dispatch to a free
|
||||
deployment and free models aren't metered in premium credits.
|
||||
|
||||
Model id conventions (shared across chat / image / vision):
|
||||
Config id conventions (shared across chat / image / vision):
|
||||
- ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` /
|
||||
``VISION_AUTO_MODE_ID``). Blocked.
|
||||
- ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium.
|
||||
|
|
@ -24,45 +24,70 @@ from typing import TYPE_CHECKING, Literal
|
|||
if TYPE_CHECKING:
|
||||
from app.db import SearchSpace
|
||||
|
||||
ModelKind = Literal["chat", "image", "vision"]
|
||||
ModelKind = Literal["llm", "image", "vision"]
|
||||
|
||||
_KIND_LABEL: dict[ModelKind, str] = {
|
||||
"chat": "chat model",
|
||||
"llm": "agent LLM",
|
||||
"image": "image generation model",
|
||||
"vision": "vision model",
|
||||
}
|
||||
|
||||
|
||||
def _is_premium_global(model_id: int) -> bool:
|
||||
"""Return True if a negative (global) model id is a premium tier model."""
|
||||
def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
|
||||
"""Return True if a negative (global) config id is a premium tier model."""
|
||||
from app.config import config as app_config
|
||||
|
||||
model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None)
|
||||
if not model:
|
||||
cfg: dict | None = None
|
||||
if kind == "llm":
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
|
||||
cfg = load_global_llm_config_by_id(config_id)
|
||||
elif kind == "image":
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
else: # vision
|
||||
cfg = next(
|
||||
(
|
||||
c
|
||||
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
|
||||
if c.get("id") == config_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not cfg:
|
||||
return False
|
||||
return str(model.get("billing_tier", "free")).lower() == "premium"
|
||||
return str(cfg.get("billing_tier", "free")).lower() == "premium"
|
||||
|
||||
|
||||
def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved model id as allowed or blocked.
|
||||
def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]:
|
||||
"""Classify a resolved config id as allowed or blocked.
|
||||
|
||||
Returns ``(allowed, reason)``; ``reason`` is empty when allowed.
|
||||
"""
|
||||
label = _KIND_LABEL[kind]
|
||||
|
||||
if model_id is None or model_id == 0:
|
||||
if config_id is None or config_id == 0:
|
||||
return (
|
||||
False,
|
||||
f"The {label} is set to Auto mode. Automations require an explicit "
|
||||
"premium model or your own (BYOK) model so every run is billable.",
|
||||
)
|
||||
|
||||
if model_id > 0:
|
||||
# Positive id -> user/search-space BYOK model. Always allowed.
|
||||
if config_id > 0:
|
||||
# Positive id → user-owned BYOK config. Always allowed.
|
||||
return True, ""
|
||||
|
||||
# Negative id -> global model. Allowed only if premium.
|
||||
if _is_premium_global(model_id):
|
||||
# Negative id → global config. Allowed only if premium.
|
||||
if _is_premium_global(kind, config_id):
|
||||
return True, ""
|
||||
|
||||
return (
|
||||
|
|
@ -74,27 +99,27 @@ def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]:
|
|||
|
||||
def get_model_eligibility(
|
||||
*,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
) -> dict:
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids.
|
||||
"""Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids.
|
||||
|
||||
The ID-based core shared by both the search-space path (creation/eligibility)
|
||||
and the captured-snapshot path (runtime backstop). Each violation is
|
||||
``{"kind", "model_id", "reason"}``.
|
||||
``{"kind", "config_id", "reason"}``.
|
||||
"""
|
||||
checks: list[tuple[ModelKind, int | None]] = [
|
||||
("chat", chat_model_id),
|
||||
("image", image_gen_model_id),
|
||||
("vision", vision_model_id),
|
||||
("llm", agent_llm_id),
|
||||
("image", image_generation_config_id),
|
||||
("vision", vision_llm_config_id),
|
||||
]
|
||||
|
||||
violations: list[dict] = []
|
||||
for kind, model_id in checks:
|
||||
allowed, reason = _classify(kind, model_id)
|
||||
for kind, config_id in checks:
|
||||
allowed, reason = _classify(kind, config_id)
|
||||
if not allowed:
|
||||
violations.append({"kind": kind, "model_id": model_id, "reason": reason})
|
||||
violations.append({"kind": kind, "config_id": config_id, "reason": reason})
|
||||
|
||||
return {"allowed": not violations, "violations": violations}
|
||||
|
||||
|
|
@ -106,9 +131,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict:
|
|||
wrapper over :func:`get_model_eligibility`.
|
||||
"""
|
||||
return get_model_eligibility(
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
image_gen_model_id=search_space.image_gen_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_config_id=search_space.vision_llm_config_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -125,9 +150,9 @@ class AutomationModelPolicyError(Exception):
|
|||
|
||||
def assert_models_billable(
|
||||
*,
|
||||
chat_model_id: int | None,
|
||||
image_gen_model_id: int | None,
|
||||
vision_model_id: int | None,
|
||||
agent_llm_id: int | None,
|
||||
image_generation_config_id: int | None,
|
||||
vision_llm_config_id: int | None,
|
||||
) -> None:
|
||||
"""Raise :class:`AutomationModelPolicyError` if any explicit id is not billable.
|
||||
|
||||
|
|
@ -135,9 +160,9 @@ def assert_models_billable(
|
|||
captured model snapshot.
|
||||
"""
|
||||
result = get_model_eligibility(
|
||||
chat_model_id=chat_model_id,
|
||||
image_gen_model_id=image_gen_model_id,
|
||||
vision_model_id=vision_model_id,
|
||||
agent_llm_id=agent_llm_id,
|
||||
image_generation_config_id=image_generation_config_id,
|
||||
vision_llm_config_id=vision_llm_config_id,
|
||||
)
|
||||
if not result["allowed"]:
|
||||
raise AutomationModelPolicyError(result["violations"])
|
||||
|
|
|
|||
|
|
@ -115,12 +115,14 @@ def init_worker(**kwargs):
|
|||
initialize_llm_router,
|
||||
initialize_openrouter_integration,
|
||||
initialize_pricing_registration,
|
||||
initialize_vision_llm_router,
|
||||
)
|
||||
|
||||
initialize_openrouter_integration()
|
||||
initialize_pricing_registration()
|
||||
initialize_llm_router()
|
||||
initialize_image_gen_router()
|
||||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Celery configuration, sourced from the central Config singleton
|
||||
|
|
@ -179,8 +181,7 @@ celery_app = Celery(
|
|||
backend=CELERY_RESULT_BACKEND,
|
||||
include=[
|
||||
"app.tasks.celery_tasks.document_tasks",
|
||||
"app.podcasts.tasks.draft",
|
||||
"app.podcasts.tasks.render",
|
||||
"app.tasks.celery_tasks.podcast_tasks",
|
||||
"app.tasks.celery_tasks.video_presentation_tasks",
|
||||
"app.tasks.celery_tasks.connector_tasks",
|
||||
"app.tasks.celery_tasks.obsidian_tasks",
|
||||
|
|
@ -188,10 +189,7 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.tasks.celery_tasks.auto_reload_task",
|
||||
"app.tasks.celery_tasks.gateway_tasks",
|
||||
"app.etl_pipeline.cache.eviction.task",
|
||||
"app.indexing_pipeline.cache.eviction.task",
|
||||
"app.automations.tasks.execute_run",
|
||||
"app.automations.triggers.builtin.schedule.selector",
|
||||
"app.automations.triggers.builtin.event.selector",
|
||||
|
|
@ -283,9 +281,16 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60, # Task expires after 60 seconds if not picked up
|
||||
},
|
||||
},
|
||||
# Reconcile Stripe credit purchases that were paid but remained pending
|
||||
"reconcile-pending-stripe-credit-purchases": {
|
||||
"task": "reconcile_pending_stripe_credit_purchases",
|
||||
# Reconcile Stripe purchases that were paid but remained pending
|
||||
"reconcile-pending-stripe-page-purchases": {
|
||||
"task": "reconcile_pending_stripe_page_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
"reconcile-pending-stripe-token-purchases": {
|
||||
"task": "reconcile_pending_stripe_token_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
|
|
@ -306,18 +311,6 @@ celery_app.conf.beat_schedule = {
|
|||
"schedule": crontab(hour="3", minute="17"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Prune the ETL parse cache (TTL + size budget) once daily, off-peak.
|
||||
"evict-etl-cache": {
|
||||
"task": "evict_etl_cache",
|
||||
"schedule": crontab(hour="4", minute="0"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Prune the embedding cache (chunk+embedding sets) once daily, off-peak.
|
||||
"evict-embedding-cache": {
|
||||
"task": "evict_embedding_cache",
|
||||
"schedule": crontab(hour="4", minute="30"),
|
||||
"options": {"expires": 600},
|
||||
},
|
||||
# Fire due automation schedule triggers (Beat entry owned by the schedule
|
||||
# trigger; see app.automations.triggers.builtin.schedule.source).
|
||||
**SCHEDULE_BEAT_SCHEDULE,
|
||||
|
|
|
|||
|
|
@ -78,7 +78,8 @@ def load_global_llm_configs():
|
|||
# stamps) never leak into the cached YAML structure.
|
||||
configs = copy.deepcopy(data.get("global_llm_configs", []))
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way.
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
|
|
@ -103,7 +104,7 @@ def load_global_llm_configs():
|
|||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=cfg.get("provider") or cfg.get("litellm_provider"),
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -119,10 +120,10 @@ def load_global_llm_configs():
|
|||
else:
|
||||
seen_slugs[slug] = cfg.get("id", 0)
|
||||
|
||||
# Stamp Auto ranking metadata. YAML configs are always
|
||||
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||
# Tier A — operator-curated, locked first when premium-eligible.
|
||||
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||
# whose provider == "openrouter" via _enrich_health.
|
||||
# whose provider == "OPENROUTER" via _enrich_health.
|
||||
try:
|
||||
from app.services.quality_score import static_score_yaml
|
||||
|
||||
|
|
@ -132,7 +133,7 @@ def load_global_llm_configs():
|
|||
cfg["quality_score_static"] = static_q
|
||||
cfg["quality_score"] = static_q
|
||||
cfg["quality_score_health"] = None
|
||||
# YAML cfgs whose provider is openrouter are also subject
|
||||
# YAML cfgs whose provider is OPENROUTER are also subject
|
||||
# to health gating against their own /endpoints data — a
|
||||
# hand-picked dead OR model is still dead. _enrich_health
|
||||
# re-stamps health_gated for them on the next refresh tick.
|
||||
|
|
@ -210,6 +211,42 @@ def load_global_image_gen_configs():
|
|||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return []
|
||||
|
||||
try:
|
||||
configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_vision_llm_router_settings():
|
||||
default_settings = {
|
||||
"routing_strategy": "usage-based-routing",
|
||||
"num_retries": 3,
|
||||
"allowed_fails": 3,
|
||||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load vision LLM router settings: {e}")
|
||||
return default_settings
|
||||
|
||||
|
||||
def load_image_gen_router_settings():
|
||||
"""
|
||||
Load router settings for image generation Auto mode from YAML file.
|
||||
|
|
@ -326,8 +363,8 @@ def initialize_openrouter_integration():
|
|||
else:
|
||||
print("Info: OpenRouter integration enabled but no models fetched")
|
||||
|
||||
# Image generation emissions reuse the catalogue already cached by
|
||||
# ``service.initialize``
|
||||
# Image generation + vision LLM emissions are opt-in (issue L).
|
||||
# Both reuse the catalogue already cached by ``service.initialize``
|
||||
# so we don't make additional network calls here.
|
||||
if settings.get("image_generation_enabled"):
|
||||
try:
|
||||
|
|
@ -341,26 +378,21 @@ def initialize_openrouter_integration():
|
|||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
|
||||
|
||||
refresh_global_model_catalog()
|
||||
if settings.get("vision_enabled"):
|
||||
try:
|
||||
vision_configs = service.get_vision_llm_configs()
|
||||
if vision_configs:
|
||||
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
|
||||
print(
|
||||
f"Info: OpenRouter integration added {len(vision_configs)} "
|
||||
f"vision LLM models"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
|
||||
|
||||
|
||||
def materialize_global_configs():
|
||||
from app.services.global_model_catalog import materialize_global_model_catalog
|
||||
|
||||
return materialize_global_model_catalog(
|
||||
chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []),
|
||||
image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []),
|
||||
)
|
||||
|
||||
|
||||
def refresh_global_model_catalog():
|
||||
connections, models = materialize_global_configs()
|
||||
config.GLOBAL_CONNECTIONS = connections
|
||||
config.GLOBAL_MODELS = models
|
||||
|
||||
|
||||
def initialize_pricing_registration():
|
||||
"""
|
||||
Teach LiteLLM the per-token cost of every deployment in
|
||||
|
|
@ -398,10 +430,7 @@ def initialize_llm_router():
|
|||
router_settings = config.ROUTER_SETTINGS
|
||||
|
||||
if not all_configs:
|
||||
print(
|
||||
"Info: No global LLM configs found; global Auto pool is unavailable. "
|
||||
"Auto can still use enabled BYOK models."
|
||||
)
|
||||
print("Info: No global LLM configs found, Auto mode will not be available")
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -446,6 +475,32 @@ def initialize_image_gen_router():
|
|||
print(f"Warning: Failed to initialize Image Generation Router: {e}")
|
||||
|
||||
|
||||
def initialize_vision_llm_router():
|
||||
vision_configs = load_global_vision_llm_configs()
|
||||
# Reuse the router settings already parsed at Config construction. The
|
||||
# *configs* list is intentionally re-read from YAML (it must exclude the
|
||||
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
|
||||
router_settings = config.VISION_LLM_ROUTER_SETTINGS
|
||||
|
||||
if not vision_configs:
|
||||
print(
|
||||
"Info: No global vision LLM configs found, "
|
||||
"Vision LLM Auto mode will not be available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
VisionLLMRouterService.initialize(vision_configs, router_settings)
|
||||
print(
|
||||
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
|
||||
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
|
|
@ -486,28 +541,6 @@ class Config:
|
|||
# Database
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
# When TRUE (default) the app ensures extensions/tables/indexes exist on
|
||||
# startup. Set FALSE in environments where schema is owned exclusively by
|
||||
# Alembic migrations to skip all boot-time DDL.
|
||||
DB_BOOTSTRAP_ON_STARTUP = (
|
||||
os.getenv("DB_BOOTSTRAP_ON_STARTUP", "TRUE").upper() == "TRUE"
|
||||
)
|
||||
# Per-session lock_timeout (ms) applied to boot-time DDL so a contended
|
||||
# CREATE INDEX / CREATE TABLE fails fast instead of hanging the FastAPI
|
||||
# lifespan forever behind another transaction's lock.
|
||||
DB_DDL_LOCK_TIMEOUT_MS = int(os.getenv("DB_DDL_LOCK_TIMEOUT_MS", "5000"))
|
||||
# Global idle_in_transaction_session_timeout (ms) applied to every pooled
|
||||
# connection so an abandoned "idle in transaction" session can't wedge the
|
||||
# database indefinitely. 0 disables. Only applied to asyncpg connections.
|
||||
DB_IDLE_IN_TX_TIMEOUT_MS = int(os.getenv("DB_IDLE_IN_TX_TIMEOUT_MS", "900000"))
|
||||
# Same protection for the separate Celery worker engine, where long-running
|
||||
# ingestion/podcast/video tasks live. Kept higher than the web default so a
|
||||
# legitimate per-document embed window is never reaped: if a task hasn't
|
||||
# touched the DB in 60 min it's treated as orphaned and dropped. 0 disables.
|
||||
DB_CELERY_IDLE_IN_TX_TIMEOUT_MS = int(
|
||||
os.getenv("DB_CELERY_IDLE_IN_TX_TIMEOUT_MS", "3600000")
|
||||
)
|
||||
|
||||
# Celery / Redis
|
||||
# Redis (single endpoint for Celery broker, result backend, and app cache).
|
||||
# Legacy CELERY_BROKER_URL / CELERY_RESULT_BACKEND / REDIS_APP_URL still
|
||||
|
|
@ -557,15 +590,14 @@ class Config:
|
|||
# Platform web search (SearXNG)
|
||||
SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST")
|
||||
|
||||
SURFSENSE_PUBLIC_URL = os.getenv("SURFSENSE_PUBLIC_URL")
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") or SURFSENSE_PUBLIC_URL
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
# Backend URL to override the http to https in the OAuth redirect URI
|
||||
BACKEND_URL = os.getenv("BACKEND_URL") or SURFSENSE_PUBLIC_URL
|
||||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
|
||||
# Messaging gateway
|
||||
# Messaging gateway (Telegram v1)
|
||||
# Global master switch: when FALSE, no gateway supervisors/workers start and all
|
||||
# gated gateway HTTP routes return 404, regardless of the per-channel flags below.
|
||||
GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "FALSE").upper() == "TRUE"
|
||||
# gateway HTTP routes return 404, regardless of the per-channel flags below.
|
||||
GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "TRUE").upper() == "TRUE"
|
||||
TELEGRAM_SHARED_BOT_TOKEN = os.getenv("TELEGRAM_SHARED_BOT_TOKEN")
|
||||
TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME")
|
||||
TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET")
|
||||
|
|
@ -608,9 +640,14 @@ class Config:
|
|||
)
|
||||
GATEWAY_DISCORD_REDIRECT_URI = os.getenv("GATEWAY_DISCORD_REDIRECT_URI")
|
||||
|
||||
# Stripe checkout (shared secrets for the unified credit wallet)
|
||||
# Stripe checkout for pay-as-you-go page packs
|
||||
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
||||
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
STRIPE_PRICE_ID = os.getenv("STRIPE_PRICE_ID")
|
||||
STRIPE_PAGES_PER_UNIT = int(os.getenv("STRIPE_PAGES_PER_UNIT", "1000"))
|
||||
STRIPE_PAGE_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_PAGE_BUYING_ENABLED", "TRUE").upper() == "TRUE"
|
||||
)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES = int(
|
||||
os.getenv("STRIPE_RECONCILIATION_LOOKBACK_MINUTES", "10")
|
||||
)
|
||||
|
|
@ -618,56 +655,27 @@ class Config:
|
|||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Unified credit wallet (micro-USD) settings.
|
||||
# Premium credit (micro-USD) quota settings.
|
||||
#
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). A single
|
||||
# ``credit_micros_balance`` funds both ETL page processing and premium
|
||||
# model calls. New users start with ``DEFAULT_CREDIT_MICROS_BALANCE``
|
||||
# ($5 by default).
|
||||
#
|
||||
# Legacy env names (``PREMIUM_CREDIT_MICROS_LIMIT`` / ``PREMIUM_TOKEN_LIMIT``,
|
||||
# ``STRIPE_PREMIUM_TOKEN_PRICE_ID``, ``STRIPE_CREDIT_MICROS_PER_UNIT`` /
|
||||
# ``STRIPE_TOKENS_PER_UNIT``, ``STRIPE_TOKEN_BUYING_ENABLED``) are still
|
||||
# honoured as fall-backs for one release; deprecation warnings fire below.
|
||||
DEFAULT_CREDIT_MICROS_BALANCE = int(
|
||||
os.getenv("DEFAULT_CREDIT_MICROS_BALANCE")
|
||||
or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
|
||||
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
|
||||
# still honoured for one release as fall-back values — the prior
|
||||
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
|
||||
# to micros, so operators upgrading without changing their .env still
|
||||
# get correct behaviour. A startup deprecation warning fires below if
|
||||
# they're set.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT = int(
|
||||
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
|
||||
)
|
||||
STRIPE_CREDIT_PRICE_ID = os.getenv("STRIPE_CREDIT_PRICE_ID") or os.getenv(
|
||||
"STRIPE_PREMIUM_TOKEN_PRICE_ID"
|
||||
)
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT = int(
|
||||
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
|
||||
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
|
||||
)
|
||||
STRIPE_CREDIT_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_CREDIT_BUYING_ENABLED")
|
||||
or os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE")
|
||||
).upper() == "TRUE"
|
||||
|
||||
# ETL page processing debits the credit wallet only when enabled. Defaults
|
||||
# to FALSE so self-hosted / OSS installs keep effectively-free ETL; hosted
|
||||
# deployments set this TRUE. 1 page == ``MICROS_PER_PAGE`` micro-USD.
|
||||
ETL_CREDIT_BILLING_ENABLED = (
|
||||
os.getenv("ETL_CREDIT_BILLING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
MICROS_PER_PAGE = int(os.getenv("MICROS_PER_PAGE", "1000"))
|
||||
|
||||
# Low-balance WARNING threshold (micro-USD). Surfaced by the quota service
|
||||
# so the UI can nudge the user to top up / enable auto-reload. $0.50.
|
||||
CREDIT_LOW_BALANCE_WARNING_MICROS = int(
|
||||
os.getenv("CREDIT_LOW_BALANCE_WARNING_MICROS", "500000")
|
||||
)
|
||||
|
||||
# Auto-reload (off-session Stripe top-up) feature flag and guards.
|
||||
AUTO_RELOAD_ENABLED = os.getenv("AUTO_RELOAD_ENABLED", "FALSE").upper() == "TRUE"
|
||||
# Minimum configurable reload amount (micro-USD). $1.00 to match pack pricing.
|
||||
AUTO_RELOAD_MIN_AMOUNT_MICROS = int(
|
||||
os.getenv("AUTO_RELOAD_MIN_AMOUNT_MICROS", "1000000")
|
||||
)
|
||||
# Cooldown so a burst of debits can't fire multiple charges (minutes).
|
||||
AUTO_RELOAD_COOLDOWN_MINUTES = int(os.getenv("AUTO_RELOAD_COOLDOWN_MINUTES", "10"))
|
||||
|
||||
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
|
||||
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
|
||||
|
|
@ -677,13 +685,14 @@ class Config:
|
|||
# reserve_tokens ≈ $0.36) with headroom.
|
||||
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
|
||||
|
||||
if (
|
||||
os.getenv("PREMIUM_TOKEN_LIMIT") or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
) and not os.getenv("DEFAULT_CREDIT_MICROS_BALANCE"):
|
||||
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT"
|
||||
):
|
||||
print(
|
||||
"Warning: PREMIUM_TOKEN_LIMIT / PREMIUM_CREDIT_MICROS_LIMIT are "
|
||||
"deprecated; rename to DEFAULT_CREDIT_MICROS_BALANCE. The old keys "
|
||||
"will be removed in a future release."
|
||||
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
|
||||
"current Stripe price). The old key will be removed in a "
|
||||
"future release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT"
|
||||
|
|
@ -693,22 +702,6 @@ class Config:
|
|||
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
|
||||
"The old key will be removed in a future release."
|
||||
)
|
||||
if os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") and not os.getenv(
|
||||
"STRIPE_CREDIT_PRICE_ID"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_PREMIUM_TOKEN_PRICE_ID is deprecated; rename to "
|
||||
"STRIPE_CREDIT_PRICE_ID. The old key will be removed in a future "
|
||||
"release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKEN_BUYING_ENABLED") and not os.getenv(
|
||||
"STRIPE_CREDIT_BUYING_ENABLED"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_TOKEN_BUYING_ENABLED is deprecated; rename to "
|
||||
"STRIPE_CREDIT_BUYING_ENABLED. The old key will be removed in a "
|
||||
"future release."
|
||||
)
|
||||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
|
|
@ -730,7 +723,7 @@ class Config:
|
|||
os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000")
|
||||
)
|
||||
|
||||
# Per-podcast reservation (in micro-USD). One chat model call generating
|
||||
# Per-podcast reservation (in micro-USD). One agent LLM call generating
|
||||
# a transcript, typically 5k-20k completion tokens. $0.20 covers a long
|
||||
# premium-model run. Tune via env.
|
||||
QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int(
|
||||
|
|
@ -836,13 +829,6 @@ class Config:
|
|||
# LLM instances are now managed per-user through the LLMConfig system
|
||||
# Legacy environment variables removed in favor of user-specific configurations
|
||||
|
||||
# True when an operator-provided global_llm_config.yaml is present.
|
||||
# Used to gate the per-search-space LLM onboarding flow: when a global
|
||||
# config file exists, search spaces inherit it and onboarding is skipped.
|
||||
GLOBAL_LLM_CONFIG_FILE_EXISTS = (
|
||||
BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
).exists()
|
||||
|
||||
# Global LLM Configurations (optional)
|
||||
# Load from global_llm_config.yaml if available
|
||||
# These can be used as default options for users
|
||||
|
|
@ -857,17 +843,11 @@ class Config:
|
|||
# Router settings for Image Generation Auto mode
|
||||
IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings()
|
||||
|
||||
# Virtual GLOBAL connection/model catalog. This is server-only metadata
|
||||
# derived from global_llm_config.yaml; GLOBAL keys are not stored in DB.
|
||||
from app.services.global_model_catalog import (
|
||||
materialize_global_model_catalog as _materialize_global_model_catalog,
|
||||
)
|
||||
# Global Vision LLM Configurations (optional)
|
||||
GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs()
|
||||
|
||||
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
|
||||
chat_configs=GLOBAL_LLM_CONFIGS,
|
||||
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
|
||||
)
|
||||
del _materialize_global_model_catalog
|
||||
# Router settings for Vision LLM Auto mode
|
||||
VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings()
|
||||
|
||||
# OpenRouter Integration settings (optional)
|
||||
OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings()
|
||||
|
|
@ -923,6 +903,9 @@ class Config:
|
|||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
||||
# Pages limit for ETL services (default to very high number for OSS unlimited usage)
|
||||
PAGES_LIMIT = int(os.getenv("PAGES_LIMIT", "999999999"))
|
||||
|
||||
if ETL_SERVICE == "UNSTRUCTURED":
|
||||
# Unstructured API Key
|
||||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
|
@ -933,47 +916,6 @@ class Config:
|
|||
AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT")
|
||||
AZURE_DI_KEY = os.getenv("AZURE_DI_KEY")
|
||||
|
||||
# ETL parse cache: reuse parser output for identical bytes across workspaces.
|
||||
ETL_CACHE_ENABLED = (
|
||||
os.getenv("ETL_CACHE_ENABLED", "false").strip().lower() == "true"
|
||||
)
|
||||
# Bump to invalidate every cached entry after a parser/behaviour change.
|
||||
ETL_CACHE_PARSER_VERSION = int(os.getenv("ETL_CACHE_PARSER_VERSION", "1"))
|
||||
ETL_CACHE_TTL_DAYS = int(os.getenv("ETL_CACHE_TTL_DAYS", "90"))
|
||||
ETL_CACHE_MAX_TOTAL_MB = int(os.getenv("ETL_CACHE_MAX_TOTAL_MB", "5120"))
|
||||
ETL_CACHE_EVICTION_BATCH = int(os.getenv("ETL_CACHE_EVICTION_BATCH", "500"))
|
||||
# Optional dedicated blob storage; unset reuses the main file_storage backend.
|
||||
ETL_CACHE_STORAGE_BACKEND = os.getenv("ETL_CACHE_STORAGE_BACKEND")
|
||||
ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER")
|
||||
ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH")
|
||||
|
||||
# Embedding cache: reuse chunk+embedding output for identical markdown across
|
||||
# workspaces. Blobs share the ETL_CACHE_STORAGE_* backend.
|
||||
EMBEDDING_CACHE_ENABLED = (
|
||||
os.getenv("EMBEDDING_CACHE_ENABLED", "false").strip().lower() == "true"
|
||||
)
|
||||
# Bump to invalidate every cached embedding set after a chunker change.
|
||||
EMBEDDING_CACHE_CHUNKER_VERSION = int(
|
||||
os.getenv("EMBEDDING_CACHE_CHUNKER_VERSION", "1")
|
||||
)
|
||||
EMBEDDING_CACHE_TTL_DAYS = int(os.getenv("EMBEDDING_CACHE_TTL_DAYS", "90"))
|
||||
EMBEDDING_CACHE_MAX_TOTAL_MB = int(
|
||||
os.getenv("EMBEDDING_CACHE_MAX_TOTAL_MB", "5120")
|
||||
)
|
||||
EMBEDDING_CACHE_EVICTION_BATCH = int(
|
||||
os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500")
|
||||
)
|
||||
|
||||
# Incremental re-indexing: on document edits, keep chunk rows whose text is
|
||||
# unchanged (reusing their embeddings) and embed only new/changed chunks.
|
||||
# Kill switch -- disabling falls back to delete-all + full re-embed.
|
||||
CHUNK_RECONCILE_ENABLED = (
|
||||
os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true"
|
||||
)
|
||||
INDEXING_CHUNK_INSERT_BATCH_SIZE = int(
|
||||
os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200")
|
||||
)
|
||||
|
||||
# Proxy provider selection. Maps to a ProxyProvider implementation registered
|
||||
# in app/utils/proxy/registry.py. Add new vendors there and switch via this var.
|
||||
PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies")
|
||||
|
|
|
|||
|
|
@ -1,236 +1,362 @@
|
|||
# Global LLM Configuration
|
||||
#
|
||||
# SETUP INSTRUCTIONS:
|
||||
# 1. Copy this file to global_llm_config.yaml.
|
||||
# 2. Replace placeholder credentials, endpoints, deployment names, and pricing
|
||||
# with values from your own provider accounts.
|
||||
# 1. For production: Copy this file to global_llm_config.yaml and add your real API keys
|
||||
# 2. For testing: The system will use this example file automatically if global_llm_config.yaml doesn't exist
|
||||
#
|
||||
# This file is intentionally safe to commit. Do not put real API keys in this
|
||||
# example file.
|
||||
# NOTE: The example API keys below are placeholders and won't work.
|
||||
# Replace them with your actual API keys to enable global configurations.
|
||||
#
|
||||
# These YAML entries are materialized at startup as server-owned GLOBAL
|
||||
# connections and models:
|
||||
# These configurations will be available to all users as a convenient option
|
||||
# Users can choose to use these global configs or add their own
|
||||
#
|
||||
# global_llm_configs -> GLOBAL chat models
|
||||
# global_image_generation_configs -> GLOBAL image generation models
|
||||
# AUTO MODE (Recommended):
|
||||
# - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs
|
||||
# - This helps avoid rate limits by distributing requests across multiple providers
|
||||
# - New users are automatically assigned Auto mode by default
|
||||
# - Configure router_settings below to customize the load balancing behavior
|
||||
#
|
||||
# Do not add global_connections or global_models sections here. They are
|
||||
# runtime-derived metadata exposed through the model-connections APIs.
|
||||
#
|
||||
# Static config shape:
|
||||
# - Connection fields: provider, api_key, api_base, api_version
|
||||
# - Model fields: model_name, billing_tier, rpm/tpm, capabilities, litellm_params
|
||||
# - Public no-login SEO metadata: seo_title, seo_description
|
||||
# - Prompt defaults: system_instructions, use_default_system_instructions,
|
||||
# citations_enabled
|
||||
#
|
||||
# Provider notes:
|
||||
# - Use the canonical provider field.
|
||||
# - For Azure, use the bare deployment name in model_name, for example
|
||||
# model_name: "gpt-5.1". The resolver prefixes the LiteLLM model string from
|
||||
# provider: "azure".
|
||||
#
|
||||
# GLOBAL ID namespace:
|
||||
# - ID 0 is reserved for Auto mode.
|
||||
# - Negative IDs are server-owned GLOBAL models.
|
||||
# - Positive IDs are user/BYOK database models.
|
||||
# - Keep static IDs unique across chat and image generation.
|
||||
# - Suggested static ranges: chat -1..-999, image -2001..-2999.
|
||||
# - Vision is not a separate config/table. Chat models that accept images use
|
||||
# supports_image_input: true.
|
||||
# Structure matches NewLLMConfig:
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
#
|
||||
# COST-BASED PREMIUM CREDITS:
|
||||
# Each premium model bills the user's USD-credit balance based on provider cost
|
||||
# reported by LiteLLM. For custom Azure deployments or any model LiteLLM does
|
||||
# not know, declare per-token costs inline:
|
||||
# Each premium config bills the user's USD-credit balance based on the
|
||||
# actual provider cost reported by LiteLLM. For models LiteLLM already
|
||||
# knows (most OpenAI/Anthropic/etc. names) you don't need to do anything.
|
||||
# For custom Azure deployment names (e.g. an in-house "gpt-5.4" deployment)
|
||||
# or any model LiteLLM doesn't have in its built-in pricing table, declare
|
||||
# per-token costs inline so they bill correctly:
|
||||
#
|
||||
# litellm_params:
|
||||
# base_model: "my-custom-deployment"
|
||||
# # USD per token; 0.00000125 == $1.25 per million input tokens.
|
||||
# input_cost_per_token: 0.00000125
|
||||
# output_cost_per_token: 0.00001
|
||||
# base_model: "my-custom-azure-deploy"
|
||||
# # USD per token; e.g. 0.000003 == $3.00 per million input tokens
|
||||
# input_cost_per_token: 0.000003
|
||||
# output_cost_per_token: 0.000015
|
||||
#
|
||||
# OpenRouter dynamic chat models pull pricing automatically from OpenRouter's
|
||||
# API. Models without resolvable pricing debit $0 and log a warning.
|
||||
# OpenRouter dynamic models pull pricing automatically from OpenRouter's
|
||||
# API — no inline declaration needed. Models without resolvable pricing
|
||||
# debit $0 from the user's balance and log a WARNING.
|
||||
|
||||
# =============================================================================
|
||||
# Chat Auto Mode Router Settings
|
||||
# =============================================================================
|
||||
# These settings control how the LiteLLM Router distributes Auto-mode requests
|
||||
# across curated router-eligible GLOBAL chat deployments.
|
||||
# Router Settings for Auto Mode
|
||||
# These settings control how the LiteLLM Router distributes requests across models
|
||||
router_settings:
|
||||
# Routing strategy options:
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage.
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting.
|
||||
# - "least-busy": Routes to least busy deployment.
|
||||
# - "latency-based-routing": Routes based on response latency.
|
||||
# - "usage-based-routing": Routes to deployment with lowest current usage (recommended for rate limits)
|
||||
# - "simple-shuffle": Random distribution with optional RPM/TPM weighting
|
||||
# - "least-busy": Routes to least busy deployment
|
||||
# - "latency-based-routing": Routes based on response latency
|
||||
routing_strategy: "usage-based-routing"
|
||||
|
||||
# Number of retries before failing
|
||||
num_retries: 3
|
||||
|
||||
# Number of failures allowed before cooling down a deployment
|
||||
allowed_fails: 3
|
||||
|
||||
# Cooldown time in seconds after allowed_fails is exceeded
|
||||
cooldown_time: 60
|
||||
# Optional fallback map:
|
||||
# fallbacks:
|
||||
# - {"azure/gpt-5.1": ["azure/gpt-5.4-mini"]}
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Chat Models
|
||||
# =============================================================================
|
||||
# Fallback models (optional) - when primary fails, try these
|
||||
# Format: [{"primary_model": ["fallback1", "fallback2"]}]
|
||||
# fallbacks: []
|
||||
|
||||
global_llm_configs:
|
||||
# Premium Azure chat model with image input support and explicit custom
|
||||
# pricing. This is the current shape to use for hosted GPT 5.x deployments.
|
||||
# Example: OpenAI GPT-4 Turbo with citations enabled
|
||||
- id: -1
|
||||
name: "Azure GPT 5.1"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-1"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.1"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version is optional. Include it if your Azure deployment requires a
|
||||
# specific API version.
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 47500
|
||||
tpm: 14750000
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.1"
|
||||
input_cost_per_token: 0.00000125
|
||||
output_cost_per_token: 0.00001
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Larger premium chat model. If your provider prices long-context traffic
|
||||
# differently, choose a conservative flat price or document the limitation
|
||||
# next to the inline pricing.
|
||||
- id: -2
|
||||
name: "Azure GPT 5.4"
|
||||
billing_tier: "premium"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
seo_slug: "azure-gpt-5-4"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4"
|
||||
supports_image_input: true
|
||||
supports_tools: true
|
||||
max_input_tokens: 400000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 150000
|
||||
tpm: 15000000
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4"
|
||||
input_cost_per_token: 0.0000025
|
||||
output_cost_per_token: 0.000015
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Free/no-login hosted model. Free models are visible to users when
|
||||
# anonymous_enabled/seo_enabled are true but do not debit premium credits.
|
||||
- id: -3
|
||||
name: "Azure GPT 5.4 Mini"
|
||||
name: "Global GPT-4 Turbo"
|
||||
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-5-4-mini-no-login"
|
||||
seo_title: "Free GPT 5.4 Mini Chat"
|
||||
seo_description: "Chat with a hosted GPT 5.4 Mini model without signing in."
|
||||
seo_slug: "gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-mini"
|
||||
supports_image_input: false
|
||||
supports_tools: true
|
||||
max_input_tokens: 128000
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 15000
|
||||
tpm: 15000000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4-turbo-preview"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
# Rate limits for load balancing (requests/tokens per minute)
|
||||
rpm: 500 # Requests per minute
|
||||
tpm: 100000 # Tokens per minute
|
||||
litellm_params:
|
||||
max_tokens: 16384
|
||||
base_model: "gpt-5.4-mini"
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Prompt Configuration
|
||||
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Anthropic Claude 3 Opus
|
||||
- id: -2
|
||||
name: "Global Claude 3 Opus"
|
||||
description: "Anthropic's most capable model with citations"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "claude-3-opus"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-opus-20240229"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Planner LLM. This is operator-only and is not shown in the user-facing
|
||||
# model selector. Only one global_llm_configs entry should set is_planner.
|
||||
# Example: Fast model - GPT-3.5 Turbo (citations disabled for speed)
|
||||
- id: -3
|
||||
name: "Global GPT-3.5 Turbo (Fast)"
|
||||
description: "Fast responses without citations for quick queries"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "gpt-3.5-turbo-fast"
|
||||
quota_reserve_tokens: 2000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-3.5-turbo"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500 # GPT-3.5 has higher rate limits
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false # Disabled for faster responses
|
||||
|
||||
# Example: Chinese LLM - DeepSeek with custom instructions
|
||||
- id: -4
|
||||
name: "Global DeepSeek Chat (Chinese)"
|
||||
description: "DeepSeek optimized for Chinese language responses"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "deepseek-chat-chinese"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "DEEPSEEK"
|
||||
model_name: "deepseek-chat"
|
||||
api_key: "your-deepseek-api-key-here"
|
||||
api_base: "https://api.deepseek.com/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# Custom system instructions for Chinese responses
|
||||
system_instructions: |
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
|
||||
</system_instruction>
|
||||
use_default_system_instructions: false
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params
|
||||
# to enable accurate token counting, cost tracking, and max token limits
|
||||
- id: -5
|
||||
name: "Global Azure GPT-4o"
|
||||
description: "Azure OpenAI GPT-4o deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4o"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
# model_name format for Azure: azure/<your-deployment-name>
|
||||
model_name: "azure/gpt-4o-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview" # Azure API version
|
||||
rpm: 1000
|
||||
tpm: 150000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
# REQUIRED for Azure: Specify the underlying OpenAI model
|
||||
# This fixes "Could not identify azure model" warnings
|
||||
# Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo
|
||||
base_model: "gpt-4o"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Azure OpenAI GPT-4 Turbo
|
||||
- id: -6
|
||||
name: "Global Azure GPT-4 Turbo"
|
||||
description: "Azure OpenAI GPT-4 Turbo deployment"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "azure-gpt-4-turbo"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "AZURE"
|
||||
model_name: "azure/gpt-4-turbo-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Groq - Fast inference
|
||||
- id: -7
|
||||
name: "Global Groq Llama 3"
|
||||
description: "Ultra-fast Llama 3 70B via Groq"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "groq-llama-3"
|
||||
quota_reserve_tokens: 8000
|
||||
provider: "GROQ"
|
||||
model_name: "llama3-70b-8192"
|
||||
api_key: "your-groq-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 30 # Groq has lower rate limits on free tier
|
||||
tpm: 14400
|
||||
litellm_params:
|
||||
temperature: 0.7
|
||||
max_tokens: 8000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: MiniMax M3 - High-performance with 512K context window
|
||||
- id: -8
|
||||
name: "Global MiniMax M3"
|
||||
description: "MiniMax M3 with 512K context window and competitive pricing"
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: true
|
||||
seo_enabled: true
|
||||
seo_slug: "minimax-m3"
|
||||
quota_reserve_tokens: 4000
|
||||
provider: "MINIMAX"
|
||||
model_name: "MiniMax-M3"
|
||||
api_key: "your-minimax-api-key-here"
|
||||
api_base: "https://api.minimax.io/v1"
|
||||
rpm: 60
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0], cannot be 0
|
||||
max_tokens: 4000
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: true
|
||||
|
||||
# Example: Planner LLM - small, fast model used for internal utility tasks
|
||||
#
|
||||
# The PLANNER role handles short, structured internal calls (KB query
|
||||
# rewriting, date extraction, recency classification, etc.) that don't
|
||||
# need frontier-tier capability. Pointing the planner at a cheap+fast
|
||||
# model (gpt-4o-mini, Claude Haiku, Azure gpt-5.x-nano, Groq Llama, ...)
|
||||
# typically saves 500ms-1.5s per turn vs. routing those same internal
|
||||
# calls through the user's chat model.
|
||||
#
|
||||
# Activation:
|
||||
# - Mark EXACTLY ONE global config with ``is_planner: true``.
|
||||
# - If multiple are marked, the first one wins and a WARNING is logged.
|
||||
# - If none is marked, every internal call falls back to the user's
|
||||
# chat LLM (same behavior as before this flag existed).
|
||||
#
|
||||
# This config is operator-only — it is NOT exposed in the user-facing
|
||||
# model selector, never billed against premium quota, and the
|
||||
# billing_tier / anonymous_enabled fields below are ignored.
|
||||
- id: -9
|
||||
name: "Azure GPT 5.x Nano Planner"
|
||||
name: "Global Planner (GPT-4o mini)"
|
||||
description: "Internal-only planner LLM for query rewriting and classification"
|
||||
is_planner: true
|
||||
billing_tier: "free"
|
||||
anonymous_enabled: false
|
||||
seo_enabled: false
|
||||
quota_reserve_tokens: 1000
|
||||
provider: "azure"
|
||||
model_name: "gpt-5.4-nano"
|
||||
supports_image_input: false
|
||||
supports_tools: false
|
||||
router_pool_eligible: false
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
rpm: 20000
|
||||
tpm: 4000000
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o-mini"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 3500
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0
|
||||
max_tokens: 1000
|
||||
base_model: "gpt-5.4-nano"
|
||||
system_instructions: ""
|
||||
use_default_system_instructions: true
|
||||
citations_enabled: false
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Dynamic Model Integration
|
||||
# OpenRouter Integration
|
||||
# =============================================================================
|
||||
# When enabled, SurfSense fetches the OpenRouter catalog at startup and injects
|
||||
# supported models as GLOBAL chat and optionally image-generation models.
|
||||
# Tier is derived per model from OpenRouter data:
|
||||
# - model id ends with ":free" -> billing_tier=free
|
||||
# - prompt and completion pricing are zero -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
#
|
||||
# Do not use deprecated openrouter_integration.billing_tier or
|
||||
# openrouter_integration.anonymous_enabled. Use the tier-specific anonymous
|
||||
# switches below.
|
||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||
# and injects them as global configs. This gives premium users access to any model
|
||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
|
||||
# while free-tier OpenRouter models show up with a green Free badge and do NOT
|
||||
# consume premium quota.
|
||||
# Models are fetched at startup and refreshed periodically in the background.
|
||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||
openrouter_integration:
|
||||
enabled: false
|
||||
api_key: "sk-or-your-openrouter-api-key"
|
||||
|
||||
# Tier is derived PER MODEL from OpenRouter's own API signals:
|
||||
# - id ends with ":free" -> billing_tier=free
|
||||
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
|
||||
# - otherwise -> billing_tier=premium
|
||||
# No global billing_tier knob is honored; any legacy value emits a startup warning.
|
||||
|
||||
# Anonymous access is split by tier so operators can expose only free
|
||||
# models to no-login users without leaking paid inference.
|
||||
anonymous_enabled_paid: false
|
||||
anonymous_enabled_free: false
|
||||
|
||||
seo_enabled: false
|
||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||
quota_reserve_tokens: 4000
|
||||
|
||||
# Base negative ID namespace for dynamic chat models. IDs are derived
|
||||
# deterministically so they survive catalog churn. Do not overlap static IDs.
|
||||
# id_offset: base negative ID for dynamically generated configs.
|
||||
# Model IDs are derived deterministically via BLAKE2b so they survive
|
||||
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
|
||||
id_offset: -10000
|
||||
|
||||
# Separate base negative ID namespace for dynamic image-generation models.
|
||||
image_id_offset: -20000
|
||||
|
||||
# How often to refresh the OpenRouter catalog. 0 means startup only.
|
||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||
refresh_interval_hours: 24
|
||||
|
||||
# Paid OpenRouter models may join curated router pools when eligible.
|
||||
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
|
||||
# for per-deployment accounting when OR premium models participate in the
|
||||
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
|
||||
# real account limits live at https://openrouter.ai/settings/limits.
|
||||
rpm: 200
|
||||
tpm: 1000000
|
||||
|
||||
# Free OpenRouter models are available for user-facing selection/pinning but
|
||||
# should be treated as a shared-account bucket, not normal router capacity.
|
||||
# Rate limits for FREE OpenRouter models. Informational only: free OR
|
||||
# models are intentionally kept OUT of the LiteLLM Router pool, because
|
||||
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
|
||||
# 50-1000 daily requests across every ":free" model combined) —
|
||||
# per-deployment router accounting can't represent a shared bucket
|
||||
# correctly. Free OR models stay fully available in the model selector
|
||||
# and for user-facing Auto thread pinning.
|
||||
free_rpm: 20
|
||||
free_tpm: 100000
|
||||
|
||||
# Image generation is opt-in to avoid injecting a large image catalog during
|
||||
# upgrades. Vision-capable chat models are represented with
|
||||
# supports_image_input: true.
|
||||
# Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue
|
||||
# contains hundreds of image- and vision-capable models; turning these on
|
||||
# injects them into the global Image-Generation / Vision-LLM model
|
||||
# selectors alongside any static configs. Tier (free/premium) is derived
|
||||
# per model the same way it is for chat (`:free` suffix or zero pricing).
|
||||
# When a user picks a premium image/vision model the call debits the
|
||||
# shared $5 USD-cost-based premium credit pool — so leaving these off
|
||||
# avoids surprise quota burn on existing deployments. Default: false.
|
||||
image_generation_enabled: false
|
||||
vision_enabled: false
|
||||
|
||||
|
|
@ -241,80 +367,191 @@ openrouter_integration:
|
|||
citations_enabled: true
|
||||
|
||||
# =============================================================================
|
||||
# Image Generation Auto Mode Router Settings
|
||||
# Image Generation Configuration
|
||||
# =============================================================================
|
||||
# These configurations power the image generation feature using litellm.aimage_generation().
|
||||
# Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock,
|
||||
# Recraft, OpenRouter, Xinference, Nscale
|
||||
#
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all image gen configs.
|
||||
|
||||
# Router Settings for Image Generation Auto Mode
|
||||
image_generation_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
# =============================================================================
|
||||
# Static GLOBAL Image Generation Models
|
||||
# =============================================================================
|
||||
global_image_generation_configs:
|
||||
- id: -2001
|
||||
name: "Azure GPT Image 1.5"
|
||||
billing_tier: "premium"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1.5"
|
||||
# Example: OpenAI DALL-E 3
|
||||
- id: -1
|
||||
name: "Global DALL-E 3"
|
||||
description: "OpenAI's DALL-E 3 for high-quality image generation"
|
||||
provider: "OPENAI"
|
||||
model_name: "dall-e-3"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
|
||||
litellm_params: {}
|
||||
|
||||
# Example: OpenAI GPT Image 1
|
||||
- id: -2
|
||||
name: "Global GPT Image 1"
|
||||
description: "OpenAI's GPT Image 1 model"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-image-1"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 50
|
||||
litellm_params: {}
|
||||
|
||||
# Example: Azure OpenAI DALL-E 3
|
||||
- id: -3
|
||||
name: "Global Azure DALL-E 3"
|
||||
description: "Azure-hosted DALL-E 3 deployment"
|
||||
provider: "AZURE_OPENAI"
|
||||
model_name: "azure/dall-e-3-deployment"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 60
|
||||
api_version: "2024-02-15-preview"
|
||||
rpm: 50
|
||||
litellm_params:
|
||||
base_model: "gpt-image-1.5"
|
||||
base_model: "dall-e-3"
|
||||
|
||||
- id: -2002
|
||||
name: "Azure GPT Image 1 Mini"
|
||||
billing_tier: "free"
|
||||
provider: "azure"
|
||||
model_name: "gpt-image-1-mini"
|
||||
api_key: "your-azure-api-key-here"
|
||||
api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2025-04-01-preview"
|
||||
rpm: 120
|
||||
litellm_params:
|
||||
base_model: "gpt-image-1-mini"
|
||||
# Example: OpenRouter Gemini Image Generation
|
||||
# - id: -4
|
||||
# name: "Global Gemini Image Gen"
|
||||
# description: "Google Gemini image generation via OpenRouter"
|
||||
# provider: "OPENROUTER"
|
||||
# model_name: "google/gemini-2.5-flash-image"
|
||||
# api_key: "your-openrouter-api-key-here"
|
||||
# api_base: ""
|
||||
# rpm: 30
|
||||
# litellm_params: {}
|
||||
|
||||
# =============================================================================
|
||||
# Field Notes
|
||||
# Vision LLM Configuration
|
||||
# =============================================================================
|
||||
# Common chat/image fields:
|
||||
# - provider: Canonical provider adapter name. Example: azure, openai,
|
||||
# anthropic, openrouter, groq, bedrock.
|
||||
# - model_name: Provider model or deployment id. For Azure, use the bare
|
||||
# deployment name. The resolver prefixes LiteLLM model strings from provider.
|
||||
# - api_base: Provider endpoint/root URL. For OpenAI-compatible providers, the
|
||||
# resolver adds /v1 when needed.
|
||||
# - api_version: Optional provider-specific API version, stored on the
|
||||
# materialized connection extra metadata.
|
||||
# - litellm_params: Passed to LiteLLM when invoking the model. Also used for
|
||||
# base_model and inline pricing registration.
|
||||
# These configurations power the vision autocomplete feature (screenshot analysis).
|
||||
# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3).
|
||||
# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock,
|
||||
# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom
|
||||
#
|
||||
# Chat model fields:
|
||||
# - supports_image_input: true when the chat model can consume image inputs.
|
||||
# - supports_tools: true when the model can use tools/function calling.
|
||||
# - max_input_tokens: Optional UI/catalog metadata for context size.
|
||||
# - router_pool_eligible: false keeps a model out of shared router pools while
|
||||
# still allowing direct selection/pinning.
|
||||
# - is_planner: true marks the internal-only planner model. Only one config
|
||||
# should set this flag.
|
||||
# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs.
|
||||
|
||||
# Router Settings for Vision LLM Auto Mode
|
||||
vision_llm_router_settings:
|
||||
routing_strategy: "usage-based-routing"
|
||||
num_retries: 3
|
||||
allowed_fails: 3
|
||||
cooldown_time: 60
|
||||
|
||||
global_vision_llm_configs:
|
||||
# Example: OpenAI GPT-4o (recommended for vision)
|
||||
- id: -1
|
||||
name: "Global GPT-4o Vision"
|
||||
description: "OpenAI's GPT-4o with strong vision capabilities"
|
||||
provider: "OPENAI"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-your-openai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 500
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Google Gemini 2.0 Flash
|
||||
- id: -2
|
||||
name: "Global Gemini 2.0 Flash"
|
||||
description: "Google's fast vision model with large context"
|
||||
provider: "GOOGLE"
|
||||
model_name: "gemini-2.0-flash"
|
||||
api_key: "your-google-ai-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 200000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Anthropic Claude 3.5 Sonnet
|
||||
- id: -3
|
||||
name: "Global Claude 3.5 Sonnet Vision"
|
||||
description: "Anthropic's Claude 3.5 Sonnet with vision support"
|
||||
provider: "ANTHROPIC"
|
||||
model_name: "claude-3-5-sonnet-20241022"
|
||||
api_key: "sk-ant-your-anthropic-api-key-here"
|
||||
api_base: ""
|
||||
rpm: 1000
|
||||
tpm: 100000
|
||||
litellm_params:
|
||||
temperature: 0.3
|
||||
max_tokens: 1000
|
||||
|
||||
# Example: Azure OpenAI GPT-4o
|
||||
# - id: -4
|
||||
# name: "Global Azure GPT-4o Vision"
|
||||
# description: "Azure-hosted GPT-4o for vision analysis"
|
||||
# provider: "AZURE_OPENAI"
|
||||
# model_name: "azure/gpt-4o-deployment"
|
||||
# api_key: "your-azure-api-key-here"
|
||||
# api_base: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-02-15-preview"
|
||||
# rpm: 500
|
||||
# tpm: 100000
|
||||
# litellm_params:
|
||||
# temperature: 0.3
|
||||
# max_tokens: 1000
|
||||
# base_model: "gpt-4o"
|
||||
|
||||
# Notes:
|
||||
# - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing
|
||||
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
|
||||
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
|
||||
# - The 'api_key' field will not be exposed to users via API
|
||||
# - system_instructions: Custom prompt or empty string to use defaults
|
||||
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
|
||||
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
|
||||
# - All standard LiteLLM providers are supported
|
||||
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
|
||||
# These help the router distribute load evenly and avoid rate limit errors
|
||||
#
|
||||
# Catalog and access fields:
|
||||
# - billing_tier: "free" or "premium".
|
||||
# - anonymous_enabled: Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Keep unique and do not change once
|
||||
# public.
|
||||
# - seo_title / seo_description: Optional SEO metadata overrides.
|
||||
# - quota_reserve_tokens: Tokens reserved before each chat LLM call.
|
||||
# - rpm / tpm: Optional rate limits for router accounting and load balancing.
|
||||
#
|
||||
# Image generation notes:
|
||||
# - Image-generation configs use the same GLOBAL ID namespace as chat models.
|
||||
# - Only RPM is relevant for most image-generation APIs.
|
||||
# - The runtime uses litellm.aimage_generation().
|
||||
# - Image billing currently uses billing_tier and model catalog metadata. Keep
|
||||
# quota reserve tuning in code/catalog unless the materializer copies a YAML
|
||||
# key for image quota reservation.
|
||||
# IMAGE GENERATION NOTES:
|
||||
# - Image generation configs use the same ID scheme as LLM configs (negative for global)
|
||||
# - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure),
|
||||
# bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter)
|
||||
# - The router uses litellm.aimage_generation() for async image generation
|
||||
# - Only RPM (requests per minute) is relevant for image generation rate limiting.
|
||||
# TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token.
|
||||
#
|
||||
# VISION LLM NOTES:
|
||||
# - Vision configs use the same ID scheme (negative for global, positive for user DB)
|
||||
# - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.)
|
||||
# - Lower temperature (0.3) is recommended for accurate screenshot analysis
|
||||
# - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions
|
||||
#
|
||||
# PLANNER LLM NOTES:
|
||||
# - is_planner: true marks a config as the internal-only planner LLM (small,
|
||||
# fast model used for KB query rewriting, date extraction, recency
|
||||
# classification, etc.). Only one config may carry this flag — if
|
||||
# multiple do, the first one wins and a startup WARNING is logged.
|
||||
# - When no config is marked is_planner, every internal utility call falls
|
||||
# back to the user's chat LLM (the historical behavior).
|
||||
# - Planner configs are NOT shown in the user-facing model selector and
|
||||
# are NOT billed against the user's premium quota. Their billing_tier,
|
||||
# anonymous_enabled, seo_* fields are ignored.
|
||||
# - Recommended models: gpt-4o-mini, claude-3-5-haiku, gemini-1.5-flash,
|
||||
# azure gpt-5.x-nano, groq llama3-8b — anything <200ms p50 on a 1-2k
|
||||
# prompt. Frontier models here defeat the purpose of the flag.
|
||||
#
|
||||
# TOKEN QUOTA & ANONYMOUS ACCESS NOTES:
|
||||
# - billing_tier: "free" or "premium". Controls whether registered users need premium token quota.
|
||||
# - anonymous_enabled: true/false. Whether the model appears in the public no-login catalog.
|
||||
# - seo_enabled: true/false. Whether a /free/<seo_slug> landing page is generated.
|
||||
# - seo_slug: Stable URL slug for SEO pages. Must be unique. Do NOT change once public.
|
||||
# - seo_title: Optional HTML title tag override for the model's /free/<slug> page.
|
||||
# - seo_description: Optional meta description override for the model's /free/<slug> page.
|
||||
# - quota_reserve_tokens: Tokens reserved before each LLM call for quota enforcement.
|
||||
# Independent of litellm_params.max_tokens. Used by the token quota service.
|
||||
|
|
|
|||
|
|
@ -90,12 +90,11 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, metadata, error
|
||||
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name),
|
||||
vision_llm=vision_llm,
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name)
|
||||
)
|
||||
markdown = result.markdown_content
|
||||
return markdown, metadata, None
|
||||
|
|
|
|||
|
|
@ -122,13 +122,12 @@ async def download_and_extract_content(
|
|||
async def _parse_file_to_markdown(
|
||||
file_path: str, filename: str, *, vision_llm=None
|
||||
) -> str:
|
||||
"""Parse a local file to markdown via the cache-aware ETL pipeline."""
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=file_path, filename=filename),
|
||||
vision_llm=vision_llm,
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
||||
|
|
|
|||
|
|
@ -84,12 +84,11 @@ async def download_and_extract_content(
|
|||
async def _parse_file_to_markdown(
|
||||
file_path: str, filename: str, *, vision_llm=None
|
||||
) -> str:
|
||||
"""Parse a local file to markdown via the cache-aware ETL pipeline."""
|
||||
from app.etl_pipeline.cache import extract_with_cache
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
result = await extract_with_cache(
|
||||
EtlRequest(file_path=file_path, filename=filename),
|
||||
vision_llm=vision_llm,
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -35,8 +34,6 @@ from app.config import config
|
|||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
||||
|
|
@ -117,6 +114,13 @@ class SearchSourceConnectorType(StrEnum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class PodcastStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class VideoPresentationStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
|
|
@ -201,15 +205,79 @@ class DocumentStatus:
|
|||
return None
|
||||
|
||||
|
||||
class ConnectionScope(StrEnum):
|
||||
GLOBAL = "GLOBAL"
|
||||
SEARCH_SPACE = "SEARCH_SPACE"
|
||||
USER = "USER"
|
||||
class LiteLLMProvider(StrEnum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
BEDROCK = "BEDROCK"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GROQ = "GROQ"
|
||||
COHERE = "COHERE"
|
||||
MISTRAL = "MISTRAL"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
REPLICATE = "REPLICATE"
|
||||
PERPLEXITY = "PERPLEXITY"
|
||||
OLLAMA = "OLLAMA"
|
||||
ALIBABA_QWEN = "ALIBABA_QWEN"
|
||||
MOONSHOT = "MOONSHOT"
|
||||
ZHIPU = "ZHIPU"
|
||||
ANYSCALE = "ANYSCALE"
|
||||
DEEPINFRA = "DEEPINFRA"
|
||||
CEREBRAS = "CEREBRAS"
|
||||
SAMBANOVA = "SAMBANOVA"
|
||||
AI21 = "AI21"
|
||||
CLOUDFLARE = "CLOUDFLARE"
|
||||
DATABRICKS = "DATABRICKS"
|
||||
COMETAPI = "COMETAPI"
|
||||
HUGGINGFACE = "HUGGINGFACE"
|
||||
GITHUB_MODELS = "GITHUB_MODELS"
|
||||
MINIMAX = "MINIMAX"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class ModelSource(StrEnum):
|
||||
DISCOVERED = "DISCOVERED"
|
||||
MANUAL = "MANUAL"
|
||||
class ImageGenProvider(StrEnum):
|
||||
"""
|
||||
Enum for image generation providers supported by LiteLLM.
|
||||
This is a subset of LLM providers — only those that support image generation.
|
||||
See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
"""
|
||||
|
||||
OPENAI = "OPENAI"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
GOOGLE = "GOOGLE" # Google AI Studio
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK" # AWS Bedrock
|
||||
RECRAFT = "RECRAFT"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
XINFERENCE = "XINFERENCE"
|
||||
NSCALE = "NSCALE"
|
||||
|
||||
|
||||
class VisionProvider(StrEnum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GOOGLE = "GOOGLE"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
BEDROCK = "BEDROCK"
|
||||
XAI = "XAI"
|
||||
OPENROUTER = "OPENROUTER"
|
||||
OLLAMA = "OLLAMA"
|
||||
GROQ = "GROQ"
|
||||
TOGETHER_AI = "TOGETHER_AI"
|
||||
FIREWORKS_AI = "FIREWORKS_AI"
|
||||
DEEPSEEK = "DEEPSEEK"
|
||||
MISTRAL = "MISTRAL"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class LogLevel(StrEnum):
|
||||
|
|
@ -252,7 +320,7 @@ class PagePurchaseStatus(StrEnum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class CreditPurchaseStatus(StrEnum):
|
||||
class PremiumTokenPurchaseStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
|
@ -264,27 +332,26 @@ INCENTIVE_TASKS_CONFIG = {
|
|||
IncentiveTaskType.GITHUB_STAR: {
|
||||
"title": "Star our GitHub repository",
|
||||
"description": "Show your support by starring SurfSense on GitHub",
|
||||
# Credit reward in USD micro-units (1_000_000 == $1.00). $0.03.
|
||||
"credit_micros_reward": 30000,
|
||||
"pages_reward": 30,
|
||||
"action_url": "https://github.com/MODSetter/SurfSense",
|
||||
},
|
||||
IncentiveTaskType.REDDIT_FOLLOW: {
|
||||
"title": "Join our Subreddit",
|
||||
"description": "Join the SurfSense community on Reddit",
|
||||
"credit_micros_reward": 30000,
|
||||
"pages_reward": 30,
|
||||
"action_url": "https://www.reddit.com/r/SurfSense/",
|
||||
},
|
||||
IncentiveTaskType.DISCORD_JOIN: {
|
||||
"title": "Join our Discord",
|
||||
"description": "Join the SurfSense community on Discord",
|
||||
"credit_micros_reward": 40000,
|
||||
"pages_reward": 40,
|
||||
"action_url": "https://discord.gg/ejRNvftDp9",
|
||||
},
|
||||
# Future tasks can be configured here:
|
||||
# IncentiveTaskType.GITHUB_ISSUE: {
|
||||
# "title": "Create an issue",
|
||||
# "description": "Help improve SurfSense by reporting bugs or suggesting features",
|
||||
# "credit_micros_reward": 50000,
|
||||
# "pages_reward": 50,
|
||||
# "action_url": "https://github.com/MODSetter/SurfSense/issues/new/choose",
|
||||
# },
|
||||
}
|
||||
|
|
@ -638,11 +705,11 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
default=False,
|
||||
server_default="false",
|
||||
)
|
||||
# Auto model pin for this thread: concrete resolved global LLM
|
||||
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||
# or clears this column (plus bulk clears when a search space's
|
||||
# chat_model_id changes). Unindexed: all reads are by primary key.
|
||||
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Surface metadata for first-party SurfSense and external chat threads.
|
||||
|
|
@ -1423,10 +1490,7 @@ class Document(BaseModel, TimestampMixin):
|
|||
created_by = relationship("User", back_populates="documents")
|
||||
connector = relationship("SearchSourceConnector", back_populates="documents")
|
||||
chunks = relationship(
|
||||
"Chunk",
|
||||
back_populates="document",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="Chunk.position",
|
||||
"Chunk", back_populates="document", cascade="all, delete-orphan"
|
||||
)
|
||||
# Original upload + future derived artifacts (redacted, filled-form).
|
||||
# Model lives in app.file_storage.persistence to keep that feature cohesive.
|
||||
|
|
@ -1462,11 +1526,6 @@ class Chunk(BaseModel, TimestampMixin):
|
|||
|
||||
content = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
# Explicit document order; ids don't follow it since incremental
|
||||
# re-indexing keeps unchanged rows across edits. Deliberately not indexed:
|
||||
# ordering reads are document-scoped (covered by ix_chunks_document_id) and
|
||||
# building a position index on the large chunks table is not worth it.
|
||||
position = Column(Integer, nullable=False, server_default="0")
|
||||
|
||||
document_id = Column(
|
||||
Integer,
|
||||
|
|
@ -1477,6 +1536,41 @@ class Chunk(BaseModel, TimestampMixin):
|
|||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
"""Podcast model for storing generated podcasts."""
|
||||
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
podcast_transcript = Column(JSONB, nullable=True)
|
||||
file_location = Column(Text, nullable=True)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
PodcastStatus,
|
||||
name="podcast_status",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=PodcastStatus.READY,
|
||||
server_default="ready",
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class VideoPresentation(BaseModel, TimestampMixin):
|
||||
"""Video presentation model for storing AI-generated video presentations.
|
||||
|
||||
|
|
@ -1548,80 +1642,73 @@ class Report(BaseModel, TimestampMixin):
|
|||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class Connection(BaseModel, TimestampMixin):
|
||||
__tablename__ = "connections"
|
||||
class ImageGenerationConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Dedicated configuration table for image generation models.
|
||||
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
base_url = Column(String(500), nullable=True)
|
||||
api_key = Column(String, nullable=True)
|
||||
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
Separate from NewLLMConfig because image generation models don't need
|
||||
system_instructions, citations_enabled, or use_default_system_instructions.
|
||||
They only need provider credentials and model parameters.
|
||||
"""
|
||||
|
||||
__tablename__ = "image_generation_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Provider & model (uses ImageGenProvider, NOT LiteLLMProvider)
|
||||
provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
# Credentials
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True) # Azure-specific
|
||||
|
||||
# Additional litellm parameters
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="image_generation_configs")
|
||||
|
||||
|
||||
class VisionLLMConfig(BaseModel, TimestampMixin):
|
||||
__tablename__ = "vision_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False)
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
model_name = Column(String(100), nullable=False)
|
||||
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
api_version = Column(String(50), nullable=True)
|
||||
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="connections")
|
||||
user = relationship("User", back_populates="connections")
|
||||
models = relationship(
|
||||
"Model",
|
||||
back_populates="connection",
|
||||
order_by="Model.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
|
||||
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
|
||||
"(scope = 'USER' AND user_id IS NOT NULL)",
|
||||
name="ck_connections_scope_owner",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Model(BaseModel, TimestampMixin):
|
||||
__tablename__ = "models"
|
||||
|
||||
connection_id = Column(
|
||||
Integer,
|
||||
ForeignKey("connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model_id = Column(String(255), nullable=False)
|
||||
display_name = Column(String(255), nullable=True)
|
||||
source = Column(
|
||||
SQLAlchemyEnum(ModelSource),
|
||||
nullable=False,
|
||||
default=ModelSource.DISCOVERED,
|
||||
server_default=ModelSource.DISCOVERED.value,
|
||||
)
|
||||
supports_chat = Column(Boolean, nullable=True)
|
||||
max_input_tokens = Column(Integer, nullable=True)
|
||||
supports_image_input = Column(Boolean, nullable=True)
|
||||
supports_tools = Column(Boolean, nullable=True)
|
||||
supports_image_generation = Column(Boolean, nullable=True)
|
||||
capabilities_override = Column(
|
||||
JSONB, nullable=False, default=dict, server_default="{}"
|
||||
)
|
||||
enabled = Column(Boolean, nullable=False, default=True, server_default="true")
|
||||
billing_tier = Column(String(50), nullable=True, index=True)
|
||||
catalog = Column(JSONB, nullable=False, default=dict, server_default="{}")
|
||||
|
||||
connection = relationship("Connection", back_populates="models")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"connection_id", "model_id", name="uq_models_connection_model_id"
|
||||
),
|
||||
Index("ix_models_model_id", "model_id"),
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="vision_llm_configs")
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
|
|
@ -1655,9 +1742,10 @@ class ImageGeneration(BaseModel, TimestampMixin):
|
|||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation model provenance.
|
||||
# 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records.
|
||||
image_gen_model_id = Column(Integer, nullable=True)
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
# positive IDs = ImageGenerationConfig records in DB
|
||||
image_generation_config_id = Column(Integer, nullable=True)
|
||||
|
||||
# Response data (full litellm response as JSONB) — present on success
|
||||
response_data = Column(JSONB, nullable=True)
|
||||
|
|
@ -1699,19 +1787,19 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
|
||||
shared_memory_md = Column(Text, nullable=True, server_default="")
|
||||
|
||||
# Connection/model role bindings.
|
||||
# Note: ID values preserve the existing convention:
|
||||
# - 0: Auto mode
|
||||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||
# - Positive IDs: User/search-space models from the models table
|
||||
chat_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
# Search space-level LLM preferences (shared by all members)
|
||||
# Note: ID values:
|
||||
# - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces
|
||||
# - Negative IDs: Global configs from YAML
|
||||
# - Positive IDs: Custom configs from DB (NewLLMConfig table)
|
||||
agent_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For agent/chat operations, defaults to Auto mode
|
||||
image_gen_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For image generation, defaults to Auto mode when eligible
|
||||
vision_model_id = Column(
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
ai_file_sort_enabled = Column(
|
||||
|
|
@ -1783,12 +1871,23 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="SearchSourceConnector.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="Connection.id",
|
||||
order_by="NewLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="search_space",
|
||||
order_by="ImageGenerationConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="VisionLLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
automations = relationship(
|
||||
|
|
@ -1891,6 +1990,64 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
documents = relationship("Document", back_populates="connector")
|
||||
|
||||
|
||||
class NewLLMConfig(BaseModel, TimestampMixin):
|
||||
"""
|
||||
New LLM configuration table that combines model settings with prompt configuration.
|
||||
|
||||
This table provides:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
|
||||
# Provider from the enum
|
||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||
# Custom provider name when provider is CUSTOM
|
||||
custom_provider = Column(String(100), nullable=True)
|
||||
# Just the model name without provider prefix
|
||||
model_name = Column(String(100), nullable=False)
|
||||
# API Key should be encrypted before storing
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
# === Prompt Configuration ===
|
||||
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
# Users can customize this from the UI
|
||||
system_instructions = Column(
|
||||
Text,
|
||||
nullable=False,
|
||||
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||
)
|
||||
# Whether to use the default system instructions when system_instructions is empty
|
||||
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
|
||||
# When disabled, an anti-citation prompt is injected instead
|
||||
citations_enabled = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# === Relationships ===
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
|
||||
|
||||
# User who created this config
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="new_llm_configs")
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
__tablename__ = "logs"
|
||||
|
||||
|
|
@ -1912,7 +2069,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin):
|
|||
"""
|
||||
Tracks completed incentive tasks for users.
|
||||
Each user can only complete each task type once.
|
||||
When a task is completed, the user's credit_micros_balance is increased.
|
||||
When a task is completed, the user's pages_limit is increased.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_incentive_tasks"
|
||||
|
|
@ -1931,8 +2088,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
task_type = Column(SQLAlchemyEnum(IncentiveTaskType), nullable=False, index=True)
|
||||
# Credit reward granted in USD micro-units (1_000_000 == $1.00).
|
||||
credit_micros_awarded = Column(BigInteger, nullable=False)
|
||||
pages_awarded = Column(Integer, nullable=False)
|
||||
completed_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
|
|
@ -1975,18 +2131,18 @@ class PagePurchase(Base, TimestampMixin):
|
|||
user = relationship("User", back_populates="page_purchases")
|
||||
|
||||
|
||||
class CreditPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant credit (USD micro-units).
|
||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
|
||||
|
||||
Renamed from ``premium_token_purchases`` in migration 156 as part of the
|
||||
unified-credits wallet. ``credit_micros_granted`` stores the USD-micro
|
||||
amount added to ``user.credit_micros_balance`` on fulfillment.
|
||||
|
||||
``source`` distinguishes a user-initiated checkout from an automatic
|
||||
off-session top-up (auto-reload), added in the auto-reload migration.
|
||||
Note: the table name is preserved (``premium_token_purchases``) for
|
||||
operational continuity even though the unit is now USD micro-credits
|
||||
instead of raw tokens. The ``credit_micros_granted`` column replaced
|
||||
the legacy ``tokens_granted`` in migration 140; the stored values
|
||||
were not transformed because the prior $1 = 1M tokens Stripe price
|
||||
makes the unit conversion 1:1 numerically.
|
||||
"""
|
||||
|
||||
__tablename__ = "credit_purchases"
|
||||
__tablename__ = "premium_token_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
|
@ -2004,18 +2160,15 @@ class CreditPurchase(Base, TimestampMixin):
|
|||
credit_micros_granted = Column(BigInteger, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
source = Column(
|
||||
String(20), nullable=False, default="checkout", server_default="checkout"
|
||||
)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(CreditPurchaseStatus),
|
||||
SQLAlchemyEnum(PremiumTokenPurchaseStatus),
|
||||
nullable=False,
|
||||
default=CreditPurchaseStatus.PENDING,
|
||||
default=PremiumTokenPurchaseStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
completed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="credit_purchases")
|
||||
user = relationship("User", back_populates="premium_token_purchases")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
|
|
@ -2257,8 +2410,22 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
@ -2281,40 +2448,33 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
credit_purchases = relationship(
|
||||
"CreditPurchase",
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Unified credit wallet (USD micro-units, 1_000_000 == $1.00).
|
||||
# Decreases on use (ETL pages + premium model calls), increases on
|
||||
# purchase / incentive grant / auto-reload. May dip slightly negative
|
||||
# when an actual cost exceeds its pre-charge estimate; UI clamps at $0.
|
||||
credit_micros_balance = Column(
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=config.PAGES_LIMIT,
|
||||
server_default=str(config.PAGES_LIMIT),
|
||||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.DEFAULT_CREDIT_MICROS_BALANCE,
|
||||
server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
# In-flight reservation holds (released/settled at finalize).
|
||||
credit_micros_reserved = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
# Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED.
|
||||
# ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the
|
||||
# saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at``
|
||||
# is set (and ``auto_reload_enabled`` flipped off) when an off-session
|
||||
# charge is declined so the UI can prompt the user to fix their card.
|
||||
stripe_customer_id = Column(String, nullable=True)
|
||||
auto_reload_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
auto_reload_threshold_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_amount_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_payment_method_id = Column(String, nullable=True)
|
||||
auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# User profile from OAuth
|
||||
display_name = Column(String, nullable=True)
|
||||
|
|
@ -2389,8 +2549,22 @@ else:
|
|||
passive_deletes=True,
|
||||
)
|
||||
|
||||
connections = relationship(
|
||||
"Connection",
|
||||
# LLM configs created by this user
|
||||
new_llm_configs = relationship(
|
||||
"NewLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# Image generation configs created by this user
|
||||
image_generation_configs = relationship(
|
||||
"ImageGenerationConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
vision_llm_configs = relationship(
|
||||
"VisionLLMConfig",
|
||||
back_populates="user",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
|
@ -2413,40 +2587,33 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
credit_purchases = relationship(
|
||||
"CreditPurchase",
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Unified credit wallet (USD micro-units, 1_000_000 == $1.00).
|
||||
# Decreases on use (ETL pages + premium model calls), increases on
|
||||
# purchase / incentive grant / auto-reload. May dip slightly negative
|
||||
# when an actual cost exceeds its pre-charge estimate; UI clamps at $0.
|
||||
credit_micros_balance = Column(
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=config.PAGES_LIMIT,
|
||||
server_default=str(config.PAGES_LIMIT),
|
||||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_credit_micros_limit = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.DEFAULT_CREDIT_MICROS_BALANCE,
|
||||
server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE),
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
)
|
||||
# In-flight reservation holds (released/settled at finalize).
|
||||
credit_micros_reserved = Column(
|
||||
premium_credit_micros_used = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
|
||||
# Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED.
|
||||
# ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the
|
||||
# saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at``
|
||||
# is set (and ``auto_reload_enabled`` flipped off) when an off-session
|
||||
# charge is declined so the UI can prompt the user to fix their card.
|
||||
stripe_customer_id = Column(String, nullable=True)
|
||||
auto_reload_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
auto_reload_threshold_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_amount_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_payment_method_id = Column(String, nullable=True)
|
||||
auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# User profile (can be set manually for non-OAuth users)
|
||||
display_name = Column(String, nullable=True)
|
||||
|
|
@ -2720,38 +2887,8 @@ from app.automations.persistence import ( # noqa: E402, F401
|
|||
AutomationRun,
|
||||
AutomationTrigger,
|
||||
)
|
||||
from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401
|
||||
from app.file_storage.persistence import DocumentFile # noqa: E402, F401
|
||||
from app.indexing_pipeline.cache.persistence.models import ( # noqa: E402, F401
|
||||
CachedEmbeddingSet,
|
||||
)
|
||||
from app.notifications.persistence import Notification # noqa: E402, F401
|
||||
from app.podcasts.persistence import ( # noqa: E402, F401
|
||||
Podcast,
|
||||
PodcastStatus,
|
||||
)
|
||||
|
||||
|
||||
def _build_connect_args() -> dict:
|
||||
"""Build driver connect_args, including a protective idle-in-transaction
|
||||
timeout for asyncpg connections.
|
||||
|
||||
A single abandoned ``idle in transaction`` session can hold table/row locks
|
||||
indefinitely and wedge writes plus boot-time DDL (the classic "FastAPI
|
||||
stuck at application startup" failure). Setting
|
||||
``idle_in_transaction_session_timeout`` server-side makes Postgres reap such
|
||||
sessions automatically. It never affects sessions that are actively running
|
||||
statements — only ones that opened a transaction and went idle.
|
||||
"""
|
||||
connect_args: dict = {}
|
||||
idle_ms = config.DB_IDLE_IN_TX_TIMEOUT_MS
|
||||
# ``server_settings`` is asyncpg-specific; only apply it for that driver.
|
||||
if idle_ms and idle_ms > 0 and DATABASE_URL and "asyncpg" in DATABASE_URL:
|
||||
connect_args["server_settings"] = {
|
||||
"idle_in_transaction_session_timeout": str(idle_ms)
|
||||
}
|
||||
return connect_args
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
|
|
@ -2760,7 +2897,6 @@ engine = create_async_engine(
|
|||
pool_recycle=1800,
|
||||
pool_pre_ping=True,
|
||||
pool_timeout=30,
|
||||
connect_args=_build_connect_args(),
|
||||
)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
|
@ -2785,117 +2921,54 @@ async def shielded_async_session():
|
|||
await session.close()
|
||||
|
||||
|
||||
# (index_name, table, CREATE statement). Built with CONCURRENTLY so an index
|
||||
# build only takes a non-blocking ShareUpdateExclusiveLock — ingestion
|
||||
# INSERT/UPDATE on documents/chunks keep flowing while the index builds, and a
|
||||
# slow build can never freeze the FastAPI lifespan or block writers.
|
||||
_INDEX_DEFINITIONS: list[tuple[str, str, str]] = [
|
||||
(
|
||||
"document_vector_index",
|
||||
"documents",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
|
||||
),
|
||||
(
|
||||
"document_search_index",
|
||||
"documents",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
|
||||
),
|
||||
(
|
||||
"chucks_vector_index",
|
||||
"chunks",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
|
||||
),
|
||||
(
|
||||
"chucks_search_index",
|
||||
"chunks",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
|
||||
),
|
||||
# pg_trgm index for efficient ILIKE '%term%' searches on titles — critical
|
||||
# for the document mention picker (@mentions) to scale.
|
||||
(
|
||||
"idx_documents_title_trgm",
|
||||
"documents",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)",
|
||||
),
|
||||
(
|
||||
"idx_documents_search_space_id",
|
||||
"documents",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)",
|
||||
),
|
||||
# Covering index for "recent documents" query — enables index-only scan.
|
||||
(
|
||||
"idx_documents_search_space_updated",
|
||||
"documents",
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _drop_invalid_index(conn, name: str) -> None:
|
||||
"""Drop a leftover *invalid* index so it can be rebuilt.
|
||||
|
||||
A ``CREATE INDEX CONCURRENTLY`` that is interrupted (timeout, crash,
|
||||
cancellation) leaves behind an ``indisvalid = false`` index. Because the
|
||||
name now exists, a later ``CREATE INDEX CONCURRENTLY IF NOT EXISTS`` would
|
||||
skip it and the broken index would persist forever. Detect and drop it
|
||||
first.
|
||||
"""
|
||||
result = await conn.execute(
|
||||
text("SELECT indisvalid FROM pg_index WHERE indexrelid = to_regclass(:n)"),
|
||||
{"n": name},
|
||||
)
|
||||
row = result.first()
|
||||
if row is not None and row[0] is False:
|
||||
logger.warning(
|
||||
"[startup] dropping invalid leftover index %s before rebuild", name
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
# Document embedding indexes
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
|
||||
)
|
||||
)
|
||||
# Document Chuck Indexes
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
|
||||
)
|
||||
)
|
||||
# pg_trgm indexes for efficient ILIKE '%term%' searches on titles
|
||||
# Critical for document mention picker (@mentions) to scale
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_title_trgm ON documents USING gin (title gin_trgm_ops)"
|
||||
)
|
||||
)
|
||||
# B-tree index on search_space_id for fast filtering
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_search_space_id ON documents (search_space_id)"
|
||||
)
|
||||
)
|
||||
# Covering index for "recent documents" query - enables index-only scan
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_search_space_updated ON documents (search_space_id, updated_at DESC NULLS LAST) INCLUDE (id, title, document_type)"
|
||||
)
|
||||
)
|
||||
await conn.execute(text(f'DROP INDEX CONCURRENTLY IF EXISTS "{name}"'))
|
||||
|
||||
|
||||
async def setup_indexes() -> None:
|
||||
"""Ensure search/vector indexes exist without ever blocking startup.
|
||||
|
||||
Each index is created with ``CONCURRENTLY`` (so it never takes a blocking
|
||||
SHARE lock on documents/chunks) under a short per-session ``lock_timeout``
|
||||
(so a contended boot fails fast instead of hanging the lifespan forever).
|
||||
Failures are logged and swallowed per-index — a missing index just gets
|
||||
retried on the next boot rather than crash-looping the API.
|
||||
"""
|
||||
lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS)
|
||||
# AUTOCOMMIT is mandatory: CREATE INDEX CONCURRENTLY cannot run inside a
|
||||
# transaction block.
|
||||
async with engine.connect() as base_conn:
|
||||
conn = await base_conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await conn.execute(text(f"SET lock_timeout = {lock_timeout_ms}"))
|
||||
for name, table, ddl in _INDEX_DEFINITIONS:
|
||||
try:
|
||||
await _drop_invalid_index(conn, name)
|
||||
await conn.execute(text(ddl))
|
||||
except Exception as exc:
|
||||
# Non-fatal by design: a missing index is retried next boot.
|
||||
logger.warning(
|
||||
"[startup] index %s on %s not ready (%s: %s); "
|
||||
"will retry on next boot",
|
||||
name,
|
||||
table,
|
||||
exc.__class__.__name__,
|
||||
exc,
|
||||
)
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
if not config.DB_BOOTSTRAP_ON_STARTUP:
|
||||
logger.info(
|
||||
"[startup] DB bootstrap skipped (DB_BOOTSTRAP_ON_STARTUP=FALSE); "
|
||||
"schema/indexes are expected to be managed by migrations"
|
||||
)
|
||||
return
|
||||
|
||||
lock_timeout_ms = int(config.DB_DDL_LOCK_TIMEOUT_MS)
|
||||
async with engine.begin() as conn:
|
||||
# Fail fast instead of hanging forever if another session holds a
|
||||
# conflicting lock on a table we need to touch.
|
||||
await conn.execute(text(f"SET LOCAL lock_timeout = {lock_timeout_ms}"))
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
"""Content-addressed reuse of expensive ETL parser output across workspaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.cached_extraction import extract_with_cache
|
||||
from app.etl_pipeline.cache.service import EtlCacheService
|
||||
|
||||
__all__ = [
|
||||
"EtlCacheService",
|
||||
"extract_with_cache",
|
||||
]
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
"""Entry point: serve ETL parses from cache, parsing only on a miss."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from app.config import config
|
||||
from app.etl_pipeline.cache.eligibility import is_parse_cacheable
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.service import EtlCacheService
|
||||
from app.etl_pipeline.cache.settings import load_etl_cache_settings
|
||||
from app.etl_pipeline.etl_document import EtlRequest, EtlResult
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.observability import metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HASH_CHUNK = 1024 * 1024
|
||||
|
||||
|
||||
async def extract_with_cache(request: EtlRequest, *, vision_llm=None) -> EtlResult:
|
||||
"""Drop-in for ``EtlPipelineService.extract`` that reuses prior parser output."""
|
||||
settings = load_etl_cache_settings()
|
||||
|
||||
cacheable = is_parse_cacheable(
|
||||
filename=request.filename,
|
||||
etl_service=config.ETL_SERVICE,
|
||||
cache_enabled=settings.enabled,
|
||||
has_vision_llm=vision_llm is not None,
|
||||
)
|
||||
if not cacheable:
|
||||
return await EtlPipelineService(vision_llm=vision_llm).extract(request)
|
||||
|
||||
key = ParseKey.for_document(
|
||||
await asyncio.to_thread(_hash_file, request.file_path),
|
||||
etl_service=config.ETL_SERVICE,
|
||||
mode=request.processing_mode.value,
|
||||
version=settings.parser_version,
|
||||
)
|
||||
|
||||
cached_result = await _recall(key)
|
||||
if cached_result is not None:
|
||||
metrics.record_etl_cache_lookup(
|
||||
etl_service=key.etl_service, mode=key.mode, outcome="hit"
|
||||
)
|
||||
logger.debug("ETL cache hit for %s", key.source_sha256)
|
||||
return cached_result
|
||||
|
||||
metrics.record_etl_cache_lookup(
|
||||
etl_service=key.etl_service, mode=key.mode, outcome="miss"
|
||||
)
|
||||
result = await EtlPipelineService(vision_llm=vision_llm).extract(request)
|
||||
await _remember(key, result)
|
||||
return result
|
||||
|
||||
|
||||
async def _recall(key: ParseKey) -> EtlResult | None:
|
||||
# Caching is best-effort: any failure falls through to a normal parse.
|
||||
try:
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
return await EtlCacheService(session).recall(key)
|
||||
except Exception:
|
||||
logger.warning("ETL cache recall failed; parsing fresh", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def _remember(key: ParseKey, result: EtlResult) -> None:
|
||||
try:
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await EtlCacheService(session).remember(key, result)
|
||||
except Exception:
|
||||
logger.warning("ETL cache write failed; result not cached", exc_info=True)
|
||||
|
||||
|
||||
def _hash_file(path: str) -> str:
|
||||
digest = hashlib.sha256()
|
||||
with open(path, "rb") as handle:
|
||||
for chunk in iter(lambda: handle.read(_HASH_CHUNK), b""):
|
||||
digest.update(chunk)
|
||||
return digest.hexdigest()
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
"""Gating rule: may this upload be served from / written to the parse cache?"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.file_classifier import FileCategory, classify_file
|
||||
|
||||
|
||||
def is_parse_cacheable(
|
||||
*,
|
||||
filename: str,
|
||||
etl_service: str | None,
|
||||
cache_enabled: bool,
|
||||
has_vision_llm: bool,
|
||||
) -> bool:
|
||||
"""Only deterministic document parses are shareable across workspaces.
|
||||
|
||||
Vision-LLM runs append model-generated content not captured by the cache key,
|
||||
and a missing ETL service means there is no document parser to key against --
|
||||
both bypass the cache. Non-document categories (plaintext, audio, images,
|
||||
direct-convert) are cheap or parser-agnostic and are handled outside it.
|
||||
"""
|
||||
if not cache_enabled:
|
||||
return False
|
||||
if has_vision_llm:
|
||||
return False
|
||||
if not etl_service:
|
||||
return False
|
||||
return classify_file(filename) == FileCategory.DOCUMENT
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""Background pruning of the parse cache by age and size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .task import evict_etl_cache_task
|
||||
|
||||
__all__ = [
|
||||
"evict_etl_cache_task",
|
||||
]
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
"""Pure selection rules for which cached entries to drop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
|
||||
|
||||
def select_over_budget(
|
||||
coldest_first: Iterable[EvictionCandidate],
|
||||
*,
|
||||
current_total_bytes: int,
|
||||
max_total_bytes: int,
|
||||
) -> list[EvictionCandidate]:
|
||||
"""Pick coldest entries until the footprint drops under the budget."""
|
||||
bytes_to_free = current_total_bytes - max_total_bytes
|
||||
if bytes_to_free <= 0:
|
||||
return []
|
||||
|
||||
chosen: list[EvictionCandidate] = []
|
||||
bytes_freed = 0
|
||||
for candidate in coldest_first:
|
||||
if bytes_freed >= bytes_to_free:
|
||||
break
|
||||
chosen.append(candidate)
|
||||
bytes_freed += candidate.size_bytes
|
||||
return chosen
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
"""Celery task that prunes the parse cache by TTL, then by size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.etl_pipeline.cache.eviction.policy import select_over_budget
|
||||
from app.etl_pipeline.cache.persistence import CachedParseRepository
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.etl_pipeline.cache.settings import load_etl_cache_settings
|
||||
from app.etl_pipeline.cache.storage import MarkdownCacheStore
|
||||
from app.observability import metrics
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="evict_etl_cache")
|
||||
def evict_etl_cache_task():
|
||||
return run_async_celery_task(_evict)
|
||||
|
||||
|
||||
async def _evict() -> None:
|
||||
"""Expire stale entries, then shed the coldest overflow only if still over budget."""
|
||||
settings = load_etl_cache_settings()
|
||||
if not settings.enabled:
|
||||
return
|
||||
|
||||
store = MarkdownCacheStore()
|
||||
async with get_celery_session_maker()() as session:
|
||||
index = CachedParseRepository(session)
|
||||
|
||||
cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days)
|
||||
expired = await index.select_expired(
|
||||
cutoff=cutoff, limit=settings.eviction_batch
|
||||
)
|
||||
await _drop(index, store, expired, phase="ttl")
|
||||
|
||||
total = await index.total_size_bytes()
|
||||
if total > settings.max_total_bytes:
|
||||
coldest = await index.select_coldest(limit=settings.eviction_batch)
|
||||
over_budget = select_over_budget(
|
||||
coldest,
|
||||
current_total_bytes=total,
|
||||
max_total_bytes=settings.max_total_bytes,
|
||||
)
|
||||
await _drop(index, store, over_budget, phase="size")
|
||||
|
||||
|
||||
async def _drop(
|
||||
index: CachedParseRepository,
|
||||
store: MarkdownCacheStore,
|
||||
candidates: list[EvictionCandidate],
|
||||
*,
|
||||
phase: str,
|
||||
) -> None:
|
||||
if not candidates:
|
||||
return
|
||||
for candidate in candidates:
|
||||
# Drop the index row even if the blob delete fails (orphan blob is harmless).
|
||||
with contextlib.suppress(Exception):
|
||||
await store.delete(candidate.storage_key)
|
||||
await index.delete_by_ids([candidate.id for candidate in candidates])
|
||||
metrics.record_etl_cache_eviction(len(candidates), phase=phase)
|
||||
logger.info("Evicted %d cached parses (%s)", len(candidates), phase)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
"""Database access for cached parse rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import CachedParse
|
||||
from .repository import CachedParseRepository
|
||||
|
||||
__all__ = [
|
||||
"CachedParse",
|
||||
"CachedParseRepository",
|
||||
]
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
"""``etl_cache_parses``: one reusable parser result per (bytes + recipe)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
|
||||
class CachedParse(BaseModel, TimestampMixin):
|
||||
__tablename__ = "etl_cache_parses"
|
||||
|
||||
# Key: raw bytes + the recipe that produced the markdown.
|
||||
source_sha256 = Column(String(64), nullable=False)
|
||||
etl_service = Column(String(32), nullable=False)
|
||||
mode = Column(String(16), nullable=False)
|
||||
parser_version = Column(Integer, nullable=False)
|
||||
|
||||
# Where the markdown blob lives (kept out of the row to stay small).
|
||||
storage_backend = Column(String(32), nullable=False)
|
||||
storage_key = Column(String, nullable=False)
|
||||
size_bytes = Column(BigInteger, nullable=False)
|
||||
|
||||
# Payload needed to rebuild the EtlResult on a hit.
|
||||
content_type = Column(String(32), nullable=False)
|
||||
actual_pages = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# Drives eviction (popularity + recency).
|
||||
times_reused = Column(BigInteger, nullable=False, default=0, server_default="0")
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"source_sha256",
|
||||
"etl_service",
|
||||
"mode",
|
||||
"parser_version",
|
||||
name="uq_etl_cache_parses_key",
|
||||
),
|
||||
Index("ix_etl_cache_parses_last_used_at", "last_used_at"),
|
||||
)
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
"""CRUD and eviction selectors for ``etl_cache_parses`` (no business rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate, ParseKey
|
||||
|
||||
from .models import CachedParse
|
||||
|
||||
_EVICTION_COLUMNS = (
|
||||
CachedParse.id,
|
||||
CachedParse.storage_key,
|
||||
CachedParse.size_bytes,
|
||||
CachedParse.last_used_at,
|
||||
CachedParse.times_reused,
|
||||
)
|
||||
|
||||
|
||||
def _as_eviction_candidate(row) -> EvictionCandidate:
|
||||
return EvictionCandidate(
|
||||
id=row.id,
|
||||
storage_key=row.storage_key,
|
||||
size_bytes=row.size_bytes,
|
||||
last_used_at=row.last_used_at,
|
||||
times_reused=row.times_reused,
|
||||
)
|
||||
|
||||
|
||||
class CachedParseRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get(self, key: ParseKey) -> CachedParse | None:
|
||||
result = await self._session.execute(
|
||||
select(CachedParse).where(
|
||||
CachedParse.source_sha256 == key.source_sha256,
|
||||
CachedParse.etl_service == key.etl_service,
|
||||
CachedParse.mode == key.mode,
|
||||
CachedParse.parser_version == key.version,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
*,
|
||||
key: ParseKey,
|
||||
content_type: str,
|
||||
actual_pages: int,
|
||||
storage_backend: str,
|
||||
storage_key: str,
|
||||
size_bytes: int,
|
||||
) -> None:
|
||||
# Concurrent writers parse identical bytes, so a lost race is harmless.
|
||||
now = datetime.now(UTC)
|
||||
await self._session.execute(
|
||||
pg_insert(CachedParse)
|
||||
.values(
|
||||
source_sha256=key.source_sha256,
|
||||
etl_service=key.etl_service,
|
||||
mode=key.mode,
|
||||
parser_version=key.version,
|
||||
content_type=content_type,
|
||||
actual_pages=actual_pages,
|
||||
storage_backend=storage_backend,
|
||||
storage_key=storage_key,
|
||||
size_bytes=size_bytes,
|
||||
times_reused=0,
|
||||
last_used_at=now,
|
||||
created_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(constraint="uq_etl_cache_parses_key")
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def mark_used(self, row_id: int) -> None:
|
||||
await self._session.execute(
|
||||
update(CachedParse)
|
||||
.where(CachedParse.id == row_id)
|
||||
.values(
|
||||
times_reused=CachedParse.times_reused + 1,
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def total_size_bytes(self) -> int:
|
||||
result = await self._session.execute(
|
||||
select(func.coalesce(func.sum(CachedParse.size_bytes), 0))
|
||||
)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def select_expired(
|
||||
self, *, cutoff: datetime, limit: int
|
||||
) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.where(CachedParse.last_used_at < cutoff)
|
||||
.order_by(CachedParse.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.order_by(CachedParse.times_reused.asc(), CachedParse.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def delete_by_ids(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
await self._session.execute(delete(CachedParse).where(CachedParse.id.in_(ids)))
|
||||
await self._session.commit()
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
"""Pure value objects for the parse cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .eviction_candidate import EvictionCandidate
|
||||
from .parse_key import ParseKey
|
||||
|
||||
__all__ = [
|
||||
"EvictionCandidate",
|
||||
"ParseKey",
|
||||
]
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
"""Row projection handed to the eviction policy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EvictionCandidate:
|
||||
id: int
|
||||
storage_key: str
|
||||
size_bytes: int
|
||||
last_used_at: datetime
|
||||
times_reused: int
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
"""Identity of a cacheable parse: equal keys yield identical markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ParseKey:
|
||||
source_sha256: str
|
||||
etl_service: str
|
||||
mode: str
|
||||
version: int
|
||||
|
||||
@classmethod
|
||||
def for_document(
|
||||
cls, source_sha256: str, *, etl_service: str, mode: str, version: int
|
||||
) -> ParseKey:
|
||||
return cls(
|
||||
source_sha256=source_sha256,
|
||||
etl_service=etl_service,
|
||||
mode=mode,
|
||||
version=version,
|
||||
)
|
||||
|
||||
@property
|
||||
def object_suffix(self) -> str:
|
||||
return f"{self.etl_service}.{self.mode}.v{self.version}.md"
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
"""Recall and remember parser output, coordinating the index and blob store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.persistence import CachedParseRepository
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.storage import MarkdownCacheStore
|
||||
from app.etl_pipeline.etl_document import EtlResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EtlCacheService:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._index = CachedParseRepository(session)
|
||||
self._store = MarkdownCacheStore()
|
||||
|
||||
async def recall(self, key: ParseKey) -> EtlResult | None:
|
||||
"""Return the cached result, or None on a miss."""
|
||||
row = await self._index.get(key)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
markdown = await self._store.load(row.storage_key)
|
||||
except Exception:
|
||||
# Index points at a blob that is gone; treat as a miss and re-parse.
|
||||
logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True)
|
||||
return None
|
||||
|
||||
await self._index.mark_used(row.id)
|
||||
return EtlResult(
|
||||
markdown_content=markdown,
|
||||
etl_service=row.etl_service,
|
||||
actual_pages=row.actual_pages,
|
||||
content_type=row.content_type,
|
||||
)
|
||||
|
||||
async def remember(self, key: ParseKey, result: EtlResult) -> None:
|
||||
"""Store a freshly parsed result for future reuse."""
|
||||
storage_key = await self._store.save(key, result.markdown_content)
|
||||
await self._index.insert(
|
||||
key=key,
|
||||
content_type=result.content_type,
|
||||
actual_pages=result.actual_pages,
|
||||
storage_backend=self._store.backend_name,
|
||||
storage_key=storage_key,
|
||||
size_bytes=len(result.markdown_content.encode("utf-8")),
|
||||
)
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
"""Cache configuration resolved from the central ``Config``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EtlCacheSettings:
|
||||
enabled: bool
|
||||
parser_version: int
|
||||
ttl_days: int
|
||||
max_total_bytes: int
|
||||
eviction_batch: int
|
||||
# None for any storage_* field means: reuse the main file_storage backend.
|
||||
storage_backend: str | None
|
||||
storage_container: str | None
|
||||
storage_local_root: str | None
|
||||
|
||||
|
||||
def load_etl_cache_settings() -> EtlCacheSettings:
|
||||
from app.config import config
|
||||
|
||||
return EtlCacheSettings(
|
||||
enabled=config.ETL_CACHE_ENABLED,
|
||||
parser_version=config.ETL_CACHE_PARSER_VERSION,
|
||||
ttl_days=config.ETL_CACHE_TTL_DAYS,
|
||||
max_total_bytes=config.ETL_CACHE_MAX_TOTAL_MB * 1024 * 1024,
|
||||
eviction_batch=config.ETL_CACHE_EVICTION_BATCH,
|
||||
storage_backend=config.ETL_CACHE_STORAGE_BACKEND or None,
|
||||
storage_container=config.ETL_CACHE_STORAGE_CONTAINER or None,
|
||||
storage_local_root=config.ETL_CACHE_STORAGE_LOCAL_PATH or None,
|
||||
)
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""Blob storage for cached parse markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .markdown_store import MarkdownCacheStore
|
||||
|
||||
__all__ = [
|
||||
"MarkdownCacheStore",
|
||||
]
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
"""Resolve the storage backend for cache blobs: shared main store or a dedicated one."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from app.file_storage.backends.base import StorageBackend
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def resolve_cache_backend() -> StorageBackend:
|
||||
from app.etl_pipeline.cache.settings import load_etl_cache_settings
|
||||
|
||||
settings = load_etl_cache_settings()
|
||||
|
||||
if not settings.storage_backend:
|
||||
from app.file_storage.factory import get_storage_backend
|
||||
|
||||
return get_storage_backend()
|
||||
|
||||
backend = settings.storage_backend.strip().lower()
|
||||
|
||||
if backend == "azure":
|
||||
from app.config import config
|
||||
|
||||
if not settings.storage_container:
|
||||
raise ValueError("ETL_CACHE_STORAGE_CONTAINER is required for azure cache.")
|
||||
if not config.AZURE_STORAGE_CONNECTION_STRING:
|
||||
raise ValueError(
|
||||
"AZURE_STORAGE_CONNECTION_STRING is required for azure cache."
|
||||
)
|
||||
from app.file_storage.backends.azure import AzureBlobBackend
|
||||
|
||||
return AzureBlobBackend(
|
||||
connection_string=config.AZURE_STORAGE_CONNECTION_STRING,
|
||||
container=settings.storage_container,
|
||||
)
|
||||
|
||||
if backend == "local":
|
||||
if not settings.storage_local_root:
|
||||
raise ValueError(
|
||||
"ETL_CACHE_STORAGE_LOCAL_PATH is required for local cache."
|
||||
)
|
||||
from app.file_storage.backends.local import LocalFileBackend
|
||||
|
||||
return LocalFileBackend(settings.storage_local_root)
|
||||
|
||||
raise ValueError(f"Unknown ETL_CACHE_STORAGE_BACKEND: {settings.storage_backend!r}")
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
"""Read and write cached markdown blobs through the resolved backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
from app.etl_pipeline.cache.storage.backend import resolve_cache_backend
|
||||
from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key
|
||||
|
||||
_MARKDOWN_CONTENT_TYPE = "text/markdown; charset=utf-8"
|
||||
|
||||
|
||||
class MarkdownCacheStore:
|
||||
def __init__(self) -> None:
|
||||
self._backend = resolve_cache_backend()
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return self._backend.backend_name
|
||||
|
||||
async def save(self, key: ParseKey, markdown: str) -> str:
|
||||
"""Persist the markdown and return its storage key for the index row."""
|
||||
storage_key = build_parse_object_key(key)
|
||||
await self._backend.put(
|
||||
storage_key,
|
||||
markdown.encode("utf-8"),
|
||||
content_type=_MARKDOWN_CONTENT_TYPE,
|
||||
)
|
||||
return storage_key
|
||||
|
||||
async def load(self, storage_key: str) -> str:
|
||||
chunks = [chunk async for chunk in self._backend.open_stream(storage_key)]
|
||||
return b"".join(chunks).decode("utf-8")
|
||||
|
||||
async def delete(self, storage_key: str) -> None:
|
||||
await self._backend.delete(storage_key)
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""Object keys for cached markdown, namespaced under a dedicated prefix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.schemas import ParseKey
|
||||
|
||||
CACHE_PREFIX = "etl_cache"
|
||||
|
||||
|
||||
def build_parse_object_key(key: ParseKey) -> str:
|
||||
# Content-addressed: identical bytes + recipe always map to the same key.
|
||||
return f"{CACHE_PREFIX}/{key.source_sha256}/{key.object_suffix}"
|
||||
|
|
@ -8,7 +8,7 @@ from app.config import config
|
|||
|
||||
|
||||
def require_gateway_enabled() -> None:
|
||||
"""FastAPI dependency that gates gateway operational routes on the global flag.
|
||||
"""FastAPI dependency that gates all gateway HTTP routes on the global flag.
|
||||
|
||||
Returns 404 (rather than 503) when ``GATEWAY_ENABLED`` is FALSE so that
|
||||
disabling the gateway makes its webhook/OAuth/pairing surface indistinguishable
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
"""Content-addressed reuse of chunk+embedding output across workspaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings
|
||||
from app.indexing_pipeline.cache.service import EmbeddingCacheService
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingCacheService",
|
||||
"build_chunk_embeddings",
|
||||
]
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
"""Entry point: serve chunk embeddings from cache, embedding only on a miss.
|
||||
|
||||
Embeddings are a pure function of the markdown, the embedding model, and the
|
||||
chunker -- so identical markdown is chunked and embedded once and reused across
|
||||
workspaces, even when it came from different sources.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import config
|
||||
from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable
|
||||
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.service import EmbeddingCacheService
|
||||
from app.indexing_pipeline.cache.settings import load_embedding_cache_settings
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.observability import metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ChunkPair = tuple[str, np.ndarray]
|
||||
|
||||
|
||||
async def build_chunk_embeddings(
|
||||
markdown: str, *, use_code_chunker: bool
|
||||
) -> tuple[np.ndarray, list[ChunkPair]]:
|
||||
"""Return the document-level vector and ordered ``(chunk_text, vector)`` pairs.
|
||||
|
||||
Drop-in for the inline chunk+embed step; reuses prior output when the same
|
||||
markdown has already been embedded with the current model and chunker.
|
||||
"""
|
||||
settings = load_embedding_cache_settings()
|
||||
chunker_kind = "code" if use_code_chunker else "hybrid"
|
||||
embedding_dim = getattr(config.embedding_model_instance, "dimension", None)
|
||||
|
||||
cacheable = is_embedding_cacheable(
|
||||
cache_enabled=settings.enabled,
|
||||
embedding_model=config.EMBEDDING_MODEL,
|
||||
embedding_dim=embedding_dim,
|
||||
)
|
||||
if not cacheable:
|
||||
return await _compute(markdown, use_code_chunker=use_code_chunker)
|
||||
|
||||
key = EmbeddingKey(
|
||||
markdown_sha256=_hash_text(markdown),
|
||||
embedding_model=config.EMBEDDING_MODEL,
|
||||
embedding_dim=int(embedding_dim),
|
||||
chunker_kind=chunker_kind,
|
||||
chunker_version=settings.chunker_version,
|
||||
)
|
||||
|
||||
cached = await _recall(key)
|
||||
if cached is not None:
|
||||
metrics.record_embedding_cache_lookup(
|
||||
embedding_model=key.embedding_model,
|
||||
chunker_kind=chunker_kind,
|
||||
outcome="hit",
|
||||
)
|
||||
logger.debug("Embedding cache hit for %s", key.markdown_sha256)
|
||||
return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks]
|
||||
|
||||
metrics.record_embedding_cache_lookup(
|
||||
embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss"
|
||||
)
|
||||
summary_embedding, chunk_pairs = await _compute(
|
||||
markdown, use_code_chunker=use_code_chunker
|
||||
)
|
||||
await _remember(key, summary_embedding, chunk_pairs)
|
||||
return summary_embedding, chunk_pairs
|
||||
|
||||
|
||||
async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]:
|
||||
"""Chunk markdown into ordered texts with the pipeline's chunker selection."""
|
||||
if use_code_chunker:
|
||||
return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True)
|
||||
# Table-aware hybrid chunker keeps Markdown tables intact (issue #1334).
|
||||
return await asyncio.to_thread(chunk_text_hybrid, markdown)
|
||||
|
||||
|
||||
async def embed_batch(texts: list[str]) -> list[np.ndarray]:
|
||||
"""Embed texts in one batch off the event loop."""
|
||||
return await asyncio.to_thread(embed_texts, texts)
|
||||
|
||||
|
||||
async def _compute(
|
||||
markdown: str, *, use_code_chunker: bool
|
||||
) -> tuple[np.ndarray, list[ChunkPair]]:
|
||||
chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker)
|
||||
embeddings = await embed_batch([markdown, *chunk_texts])
|
||||
summary_embedding, *chunk_embeddings = embeddings
|
||||
return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False))
|
||||
|
||||
|
||||
async def _recall(key: EmbeddingKey) -> EmbeddingSet | None:
|
||||
# Caching is best-effort: any failure falls through to a normal embed.
|
||||
try:
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
return await EmbeddingCacheService(session).recall(key)
|
||||
except Exception:
|
||||
logger.warning("Embedding cache recall failed; embedding fresh", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def _remember(
|
||||
key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair]
|
||||
) -> None:
|
||||
try:
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
embedding_set = EmbeddingSet(
|
||||
summary_embedding=summary_embedding,
|
||||
chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs],
|
||||
)
|
||||
async with get_celery_session_maker()() as session:
|
||||
await EmbeddingCacheService(session).remember(key, embedding_set)
|
||||
except Exception:
|
||||
logger.warning("Embedding cache write failed; result not cached", exc_info=True)
|
||||
|
||||
|
||||
def _hash_text(text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
"""Gating rule: may this document be served from / written to the embedding cache?"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def is_embedding_cacheable(
|
||||
*,
|
||||
cache_enabled: bool,
|
||||
embedding_model: str | None,
|
||||
embedding_dim: int | None,
|
||||
) -> bool:
|
||||
"""Cache only when a concrete embedding model and dimension are configured.
|
||||
|
||||
Without a model there is nothing to key against, and without a dimension the
|
||||
blob's integrity guard cannot run -- both bypass the cache.
|
||||
"""
|
||||
if not cache_enabled:
|
||||
return False
|
||||
if not embedding_model:
|
||||
return False
|
||||
return bool(embedding_dim)
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""Background pruning of the embedding cache by age and size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .task import evict_embedding_cache_task
|
||||
|
||||
__all__ = [
|
||||
"evict_embedding_cache_task",
|
||||
]
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
"""Celery task that prunes the embedding cache by TTL, then by size budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.etl_pipeline.cache.eviction.policy import select_over_budget
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository
|
||||
from app.indexing_pipeline.cache.settings import load_embedding_cache_settings
|
||||
from app.indexing_pipeline.cache.storage import EmbeddingCacheStore
|
||||
from app.observability import metrics
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="evict_embedding_cache")
|
||||
def evict_embedding_cache_task():
|
||||
return run_async_celery_task(_evict)
|
||||
|
||||
|
||||
async def _evict() -> None:
|
||||
"""Expire stale entries, then shed the coldest overflow only if still over budget."""
|
||||
settings = load_embedding_cache_settings()
|
||||
if not settings.enabled:
|
||||
return
|
||||
|
||||
store = EmbeddingCacheStore()
|
||||
async with get_celery_session_maker()() as session:
|
||||
index = CachedEmbeddingSetRepository(session)
|
||||
|
||||
cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days)
|
||||
expired = await index.select_expired(
|
||||
cutoff=cutoff, limit=settings.eviction_batch
|
||||
)
|
||||
await _drop(index, store, expired, phase="ttl")
|
||||
|
||||
total = await index.total_size_bytes()
|
||||
if total > settings.max_total_bytes:
|
||||
coldest = await index.select_coldest(limit=settings.eviction_batch)
|
||||
over_budget = select_over_budget(
|
||||
coldest,
|
||||
current_total_bytes=total,
|
||||
max_total_bytes=settings.max_total_bytes,
|
||||
)
|
||||
await _drop(index, store, over_budget, phase="size")
|
||||
|
||||
|
||||
async def _drop(
|
||||
index: CachedEmbeddingSetRepository,
|
||||
store: EmbeddingCacheStore,
|
||||
candidates: list[EvictionCandidate],
|
||||
*,
|
||||
phase: str,
|
||||
) -> None:
|
||||
if not candidates:
|
||||
return
|
||||
for candidate in candidates:
|
||||
# Drop the index row even if the blob delete fails (orphan blob is harmless).
|
||||
with contextlib.suppress(Exception):
|
||||
await store.delete(candidate.storage_key)
|
||||
await index.delete_by_ids([candidate.id for candidate in candidates])
|
||||
metrics.record_embedding_cache_eviction(len(candidates), phase=phase)
|
||||
logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
"""Database access for cached embedding sets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import CachedEmbeddingSet
|
||||
from .repository import CachedEmbeddingSetRepository
|
||||
|
||||
__all__ = [
|
||||
"CachedEmbeddingSet",
|
||||
"CachedEmbeddingSetRepository",
|
||||
]
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
"""``embedding_cache_sets``: one reusable chunk+embedding set per markdown."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
|
||||
class CachedEmbeddingSet(BaseModel, TimestampMixin):
|
||||
__tablename__ = "embedding_cache_sets"
|
||||
|
||||
# Key: markdown text + the recipe that turned it into vectors.
|
||||
markdown_sha256 = Column(String(64), nullable=False)
|
||||
embedding_model = Column(String(255), nullable=False)
|
||||
embedding_dim = Column(Integer, nullable=False)
|
||||
chunker_kind = Column(String(8), nullable=False)
|
||||
chunker_version = Column(Integer, nullable=False)
|
||||
|
||||
# Where the embedding blob lives (kept out of the row to stay small).
|
||||
storage_backend = Column(String(32), nullable=False)
|
||||
storage_key = Column(String, nullable=False)
|
||||
size_bytes = Column(BigInteger, nullable=False)
|
||||
chunk_count = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# Drives eviction (popularity + recency).
|
||||
times_reused = Column(BigInteger, nullable=False, default=0, server_default="0")
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"markdown_sha256",
|
||||
"embedding_model",
|
||||
"chunker_kind",
|
||||
"chunker_version",
|
||||
name="uq_embedding_cache_sets_key",
|
||||
),
|
||||
Index("ix_embedding_cache_sets_last_used_at", "last_used_at"),
|
||||
)
|
||||
|
|
@ -1,126 +0,0 @@
|
|||
"""CRUD and eviction selectors for ``embedding_cache_sets`` (no business rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.etl_pipeline.cache.schemas import EvictionCandidate
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey
|
||||
|
||||
from .models import CachedEmbeddingSet
|
||||
|
||||
_EVICTION_COLUMNS = (
|
||||
CachedEmbeddingSet.id,
|
||||
CachedEmbeddingSet.storage_key,
|
||||
CachedEmbeddingSet.size_bytes,
|
||||
CachedEmbeddingSet.last_used_at,
|
||||
CachedEmbeddingSet.times_reused,
|
||||
)
|
||||
|
||||
|
||||
def _as_eviction_candidate(row) -> EvictionCandidate:
|
||||
return EvictionCandidate(
|
||||
id=row.id,
|
||||
storage_key=row.storage_key,
|
||||
size_bytes=row.size_bytes,
|
||||
last_used_at=row.last_used_at,
|
||||
times_reused=row.times_reused,
|
||||
)
|
||||
|
||||
|
||||
class CachedEmbeddingSetRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get(self, key: EmbeddingKey) -> CachedEmbeddingSet | None:
|
||||
result = await self._session.execute(
|
||||
select(CachedEmbeddingSet).where(
|
||||
CachedEmbeddingSet.markdown_sha256 == key.markdown_sha256,
|
||||
CachedEmbeddingSet.embedding_model == key.embedding_model,
|
||||
CachedEmbeddingSet.chunker_kind == key.chunker_kind,
|
||||
CachedEmbeddingSet.chunker_version == key.chunker_version,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
*,
|
||||
key: EmbeddingKey,
|
||||
storage_backend: str,
|
||||
storage_key: str,
|
||||
size_bytes: int,
|
||||
chunk_count: int,
|
||||
) -> None:
|
||||
# Concurrent writers embed identical markdown, so a lost race is harmless.
|
||||
now = datetime.now(UTC)
|
||||
await self._session.execute(
|
||||
pg_insert(CachedEmbeddingSet)
|
||||
.values(
|
||||
markdown_sha256=key.markdown_sha256,
|
||||
embedding_model=key.embedding_model,
|
||||
embedding_dim=key.embedding_dim,
|
||||
chunker_kind=key.chunker_kind,
|
||||
chunker_version=key.chunker_version,
|
||||
storage_backend=storage_backend,
|
||||
storage_key=storage_key,
|
||||
size_bytes=size_bytes,
|
||||
chunk_count=chunk_count,
|
||||
times_reused=0,
|
||||
last_used_at=now,
|
||||
created_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(constraint="uq_embedding_cache_sets_key")
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def mark_used(self, row_id: int) -> None:
|
||||
await self._session.execute(
|
||||
update(CachedEmbeddingSet)
|
||||
.where(CachedEmbeddingSet.id == row_id)
|
||||
.values(
|
||||
times_reused=CachedEmbeddingSet.times_reused + 1,
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await self._session.commit()
|
||||
|
||||
async def total_size_bytes(self) -> int:
|
||||
result = await self._session.execute(
|
||||
select(func.coalesce(func.sum(CachedEmbeddingSet.size_bytes), 0))
|
||||
)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def select_expired(
|
||||
self, *, cutoff: datetime, limit: int
|
||||
) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.where(CachedEmbeddingSet.last_used_at < cutoff)
|
||||
.order_by(CachedEmbeddingSet.last_used_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]:
|
||||
result = await self._session.execute(
|
||||
select(*_EVICTION_COLUMNS)
|
||||
.order_by(
|
||||
CachedEmbeddingSet.times_reused.asc(),
|
||||
CachedEmbeddingSet.last_used_at.asc(),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
return [_as_eviction_candidate(row) for row in result]
|
||||
|
||||
async def delete_by_ids(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
await self._session.execute(
|
||||
delete(CachedEmbeddingSet).where(CachedEmbeddingSet.id.in_(ids))
|
||||
)
|
||||
await self._session.commit()
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""Pure value objects for the embedding cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .embedding_key import EmbeddingKey
|
||||
from .embedding_set import CachedChunk, EmbeddingSet
|
||||
|
||||
__all__ = [
|
||||
"CachedChunk",
|
||||
"EmbeddingKey",
|
||||
"EmbeddingSet",
|
||||
]
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
"""Identity of a cacheable embedding set: equal keys yield identical vectors.
|
||||
|
||||
Embeddings depend on the markdown text, the embedding model, and the chunker --
|
||||
never on how the markdown was produced. So the key is the markdown's own hash
|
||||
plus the model and chunker recipe, not the upstream parse identity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EmbeddingKey:
|
||||
markdown_sha256: str
|
||||
embedding_model: str
|
||||
embedding_dim: int
|
||||
chunker_kind: str
|
||||
chunker_version: int
|
||||
|
||||
@property
|
||||
def object_suffix(self) -> str:
|
||||
# Fingerprint the model so distinct models never share a blob, while the
|
||||
# markdown hash (the object's folder) stays human-readable.
|
||||
fingerprint = hashlib.sha256(self.embedding_model.encode("utf-8")).hexdigest()
|
||||
return f"{fingerprint[:16]}.{self.chunker_kind}.v{self.chunker_version}.emb"
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
"""The cached payload: a document's chunk texts paired with their vectors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CachedChunk:
|
||||
text: str
|
||||
embedding: np.ndarray
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EmbeddingSet:
|
||||
"""Everything the indexer needs to rebuild a document's chunks without embedding.
|
||||
|
||||
``summary_embedding`` is the document-level vector; ``chunks`` are the ordered
|
||||
chunk texts and their vectors.
|
||||
"""
|
||||
|
||||
summary_embedding: np.ndarray
|
||||
chunks: list[CachedChunk]
|
||||
|
||||
@property
|
||||
def chunk_count(self) -> int:
|
||||
return len(self.chunks)
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle).
|
||||
|
||||
Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header
|
||||
carries the dim, chunk count, and ordered chunk texts; the matrix holds the
|
||||
summary vector followed by one row per chunk, all float32 for compactness.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet
|
||||
|
||||
# Marker at the start of every blob: "SurfSense EMBeddings, version 1"-> SSEMB1. Lets us
|
||||
# reject foreign blobs and bump the trailing digit if the layout ever changes.
|
||||
_MAGIC = b"SSEMB1"
|
||||
# 4-byte big-endian unsigned int written before the variable-length JSON header,
|
||||
# so the reader knows where the header ends and the float matrix begins.
|
||||
_HEADER_LEN = struct.Struct(">I")
|
||||
|
||||
|
||||
def serialize(embedding_set: EmbeddingSet) -> bytes:
|
||||
summary = np.asarray(embedding_set.summary_embedding, dtype=np.float32).reshape(-1)
|
||||
dim = int(summary.shape[0])
|
||||
|
||||
rows = [summary]
|
||||
texts: list[str] = []
|
||||
for chunk in embedding_set.chunks:
|
||||
vector = np.asarray(chunk.embedding, dtype=np.float32).reshape(-1)
|
||||
if vector.shape[0] != dim:
|
||||
raise ValueError(
|
||||
"All vectors in an embedding set must share one dimension."
|
||||
)
|
||||
rows.append(vector)
|
||||
texts.append(chunk.text)
|
||||
|
||||
matrix = np.stack(rows, axis=0)
|
||||
header = json.dumps(
|
||||
{"dim": dim, "count": len(texts), "texts": texts}, ensure_ascii=False
|
||||
).encode("utf-8")
|
||||
return b"".join(
|
||||
[_MAGIC, _HEADER_LEN.pack(len(header)), header, matrix.tobytes(order="C")]
|
||||
)
|
||||
|
||||
|
||||
def deserialize(blob: bytes) -> EmbeddingSet:
|
||||
view = memoryview(blob)
|
||||
if bytes(view[: len(_MAGIC)]) != _MAGIC:
|
||||
raise ValueError("Unrecognized embedding cache blob.")
|
||||
|
||||
offset = len(_MAGIC)
|
||||
(header_len,) = _HEADER_LEN.unpack(view[offset : offset + _HEADER_LEN.size])
|
||||
offset += _HEADER_LEN.size
|
||||
|
||||
header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8"))
|
||||
offset += header_len
|
||||
|
||||
dim = int(header["dim"])
|
||||
count = int(header["count"])
|
||||
texts: list[str] = header["texts"]
|
||||
|
||||
matrix = np.frombuffer(view[offset:], dtype=np.float32)
|
||||
if matrix.shape[0] != (count + 1) * dim:
|
||||
raise ValueError("Embedding cache blob is truncated or corrupt.")
|
||||
matrix = matrix.reshape(count + 1, dim)
|
||||
|
||||
return EmbeddingSet(
|
||||
summary_embedding=matrix[0],
|
||||
chunks=[
|
||||
CachedChunk(text=texts[i], embedding=matrix[i + 1]) for i in range(count)
|
||||
],
|
||||
)
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
"""Recall and remember embedding sets, coordinating the index and blob store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.storage import EmbeddingCacheStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingCacheService:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._index = CachedEmbeddingSetRepository(session)
|
||||
self._store = EmbeddingCacheStore()
|
||||
|
||||
async def recall(self, key: EmbeddingKey) -> EmbeddingSet | None:
|
||||
"""Return the cached embedding set, or None on a miss."""
|
||||
row = await self._index.get(key)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
embedding_set = await self._store.load(row.storage_key)
|
||||
except Exception:
|
||||
# Index points at a blob that is gone; treat as a miss and re-embed.
|
||||
logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True)
|
||||
return None
|
||||
|
||||
if int(embedding_set.summary_embedding.shape[0]) != key.embedding_dim:
|
||||
# A model swapped its dimension under a reused name; never serve it.
|
||||
logger.warning("Cached embedding dimension mismatch: %s", row.storage_key)
|
||||
return None
|
||||
|
||||
await self._index.mark_used(row.id)
|
||||
return embedding_set
|
||||
|
||||
async def remember(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> None:
|
||||
"""Store a freshly embedded set for future reuse."""
|
||||
storage_key, size_bytes = await self._store.save(key, embedding_set)
|
||||
await self._index.insert(
|
||||
key=key,
|
||||
storage_backend=self._store.backend_name,
|
||||
storage_key=storage_key,
|
||||
size_bytes=size_bytes,
|
||||
chunk_count=embedding_set.chunk_count,
|
||||
)
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""Embedding-cache configuration resolved from the central ``Config``.
|
||||
|
||||
The blob backend is intentionally not configured here: it is shared with the ETL
|
||||
parse cache (see ``ETL_CACHE_STORAGE_*``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EmbeddingCacheSettings:
|
||||
enabled: bool
|
||||
chunker_version: int
|
||||
ttl_days: int
|
||||
max_total_bytes: int
|
||||
eviction_batch: int
|
||||
|
||||
|
||||
def load_embedding_cache_settings() -> EmbeddingCacheSettings:
|
||||
from app.config import config
|
||||
|
||||
return EmbeddingCacheSettings(
|
||||
enabled=config.EMBEDDING_CACHE_ENABLED,
|
||||
chunker_version=config.EMBEDDING_CACHE_CHUNKER_VERSION,
|
||||
ttl_days=config.EMBEDDING_CACHE_TTL_DAYS,
|
||||
max_total_bytes=config.EMBEDDING_CACHE_MAX_TOTAL_MB * 1024 * 1024,
|
||||
eviction_batch=config.EMBEDDING_CACHE_EVICTION_BATCH,
|
||||
)
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""Blob storage for cached embedding sets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .embedding_store import EmbeddingCacheStore
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingCacheStore",
|
||||
]
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
"""Read and write cached embedding blobs through the shared cache backend.
|
||||
|
||||
The blob backend is shared with the ETL parse cache (same bucket / root), so
|
||||
markdown and its embeddings live side by side; only the object prefix differs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.etl_pipeline.cache.storage.backend import resolve_cache_backend
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.serialization import deserialize, serialize
|
||||
from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key
|
||||
|
||||
_EMBEDDING_CONTENT_TYPE = "application/octet-stream"
|
||||
|
||||
|
||||
class EmbeddingCacheStore:
|
||||
def __init__(self) -> None:
|
||||
self._backend = resolve_cache_backend()
|
||||
|
||||
@property
|
||||
def backend_name(self) -> str:
|
||||
return self._backend.backend_name
|
||||
|
||||
async def save(
|
||||
self, key: EmbeddingKey, embedding_set: EmbeddingSet
|
||||
) -> tuple[str, int]:
|
||||
"""Persist the embedding set and return its storage key and byte size."""
|
||||
blob = serialize(embedding_set)
|
||||
storage_key = build_embedding_object_key(key)
|
||||
await self._backend.put(storage_key, blob, content_type=_EMBEDDING_CONTENT_TYPE)
|
||||
return storage_key, len(blob)
|
||||
|
||||
async def load(self, storage_key: str) -> EmbeddingSet:
|
||||
chunks = [chunk async for chunk in self._backend.open_stream(storage_key)]
|
||||
return deserialize(b"".join(chunks))
|
||||
|
||||
async def delete(self, storage_key: str) -> None:
|
||||
await self._backend.delete(storage_key)
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""Object keys for cached embedding sets, namespaced under a dedicated prefix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.indexing_pipeline.cache.schemas import EmbeddingKey
|
||||
|
||||
CACHE_PREFIX = "embedding_cache"
|
||||
|
||||
|
||||
def build_embedding_object_key(key: EmbeddingKey) -> str:
|
||||
# Content-addressed: identical markdown + recipe always map to the same key.
|
||||
return f"{CACHE_PREFIX}/{key.markdown_sha256}/{key.object_suffix}"
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
"""Diff a document's existing chunk rows against its freshly chunked texts.
|
||||
|
||||
Embeddings are a pure function of chunk text, so a row whose content reappears
|
||||
in the new chunking keeps its embedding (and its HNSW/GIN index entries); only
|
||||
genuinely new texts are embedded and only vanished rows are deleted. Matching
|
||||
is a greedy multiset match on content in document order, so duplicate
|
||||
boilerplate chunks pair up one-to-one and reordered chunks become cheap
|
||||
position updates instead of delete+reinsert.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ExistingChunk:
|
||||
id: int
|
||||
content: str
|
||||
position: int
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ChunkPlan:
|
||||
"""The minimal set of writes that turns the stored chunks into the new ones.
|
||||
|
||||
``reused`` holds only kept rows whose position actually changed; rows that
|
||||
match in place need no write at all. Kept-row count (for metrics) is
|
||||
``len(existing) - len(to_delete)``.
|
||||
"""
|
||||
|
||||
reused: list[tuple[int, int]] # (existing_chunk_id, new_position)
|
||||
to_embed: list[tuple[int, str]] # (new_position, text)
|
||||
to_delete: list[int] # existing chunk ids
|
||||
|
||||
|
||||
def reconcile(existing: list[ExistingChunk], new_texts: list[str]) -> ChunkPlan:
|
||||
available: dict[str, deque[ExistingChunk]] = defaultdict(deque)
|
||||
for chunk in sorted(existing, key=lambda c: c.position):
|
||||
available[chunk.content].append(chunk)
|
||||
|
||||
reused: list[tuple[int, int]] = []
|
||||
to_embed: list[tuple[int, str]] = []
|
||||
|
||||
for new_position, text in enumerate(new_texts):
|
||||
matches = available.get(text)
|
||||
if matches:
|
||||
chunk = matches.popleft()
|
||||
if chunk.position != new_position:
|
||||
reused.append((chunk.id, new_position))
|
||||
else:
|
||||
to_embed.append((new_position, text))
|
||||
|
||||
to_delete = [chunk.id for queue in available.values() for chunk in queue]
|
||||
return ChunkPlan(reused=reused, to_embed=to_embed, to_delete=to_delete)
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.db import Document, DocumentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,6 +22,7 @@ async def rollback_and_persist_failure(
|
|||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
# Session is completely dead; surface it but never raise.
|
||||
logger.warning(
|
||||
"Rollback failed; cannot persist failed status for document %s",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -34,6 +35,8 @@ async def rollback_and_persist_failure(
|
|||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
# Best-effort: the document stays non-ready and is retried next sync.
|
||||
# Log it so a permanently-stuck document is at least traceable.
|
||||
logger.warning(
|
||||
"Could not persist failed status for document %s; will retry next sync",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -43,60 +46,12 @@ async def rollback_and_persist_failure(
|
|||
await session.rollback()
|
||||
|
||||
|
||||
async def persist_scratch_index(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
content: str,
|
||||
chunks: list[Chunk],
|
||||
*,
|
||||
batch_size: int,
|
||||
perf: logging.Logger,
|
||||
) -> None:
|
||||
"""Commit document content first, then chunk rows in batches, then mark ready."""
|
||||
if document.id is None:
|
||||
raise ValueError("document.id is required to persist chunks")
|
||||
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
t_persist = time.perf_counter()
|
||||
total = len(chunks)
|
||||
if total == 0:
|
||||
set_committed_value(document, "chunks", [])
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
effective_batch = total if batch_size <= 0 else batch_size
|
||||
num_batches = (total + effective_batch - 1) // effective_batch
|
||||
doc_id = document.id
|
||||
|
||||
for batch_idx, start in enumerate(range(0, total, effective_batch), start=1):
|
||||
batch = chunks[start : start + effective_batch]
|
||||
t_batch = time.perf_counter()
|
||||
for chunk in batch:
|
||||
chunk.document_id = doc_id
|
||||
session.add_all(batch)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs",
|
||||
doc_id,
|
||||
batch_idx,
|
||||
num_batches,
|
||||
len(batch),
|
||||
time.perf_counter() - t_batch,
|
||||
)
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs",
|
||||
doc_id,
|
||||
total,
|
||||
num_batches,
|
||||
time.perf_counter() - t_persist,
|
||||
)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ from litellm.exceptions import (
|
|||
)
|
||||
from sqlalchemy.exc import IntegrityError as IntegrityError
|
||||
|
||||
from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception
|
||||
|
||||
# Tuples for use directly in except clauses.
|
||||
RETRYABLE_LLM_ERRORS = (
|
||||
RateLimitError,
|
||||
|
|
@ -99,20 +97,38 @@ def safe_exception_message(exc: Exception) -> str:
|
|||
|
||||
def llm_retryable_message(exc: Exception) -> str:
|
||||
try:
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
if isinstance(exc, RateLimitError):
|
||||
return PipelineMessages.RATE_LIMIT
|
||||
if isinstance(exc, Timeout):
|
||||
return PipelineMessages.LLM_TIMEOUT
|
||||
if isinstance(exc, ServiceUnavailableError):
|
||||
return PipelineMessages.LLM_UNAVAILABLE
|
||||
if isinstance(exc, BadGatewayError):
|
||||
return PipelineMessages.LLM_BAD_GATEWAY
|
||||
if isinstance(exc, InternalServerError):
|
||||
return PipelineMessages.LLM_SERVER_ERROR
|
||||
if isinstance(exc, APIConnectionError):
|
||||
return PipelineMessages.LLM_CONNECTION
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
||||
def llm_permanent_message(exc: Exception) -> str:
|
||||
try:
|
||||
adapted = adapt_llm_exception(exc)
|
||||
if adapted.category is LLMErrorCategory.UNKNOWN:
|
||||
return safe_exception_message(exc)
|
||||
return adapted.user_message
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return PipelineMessages.LLM_AUTH
|
||||
if isinstance(exc, PermissionDeniedError):
|
||||
return PipelineMessages.LLM_PERMISSION
|
||||
if isinstance(exc, NotFoundError):
|
||||
return PipelineMessages.LLM_NOT_FOUND
|
||||
if isinstance(exc, BadRequestError):
|
||||
return PipelineMessages.LLM_BAD_REQUEST
|
||||
if isinstance(exc, UnprocessableEntityError):
|
||||
return PipelineMessages.LLM_UNPROCESSABLE
|
||||
if isinstance(exc, APIResponseValidationError):
|
||||
return PipelineMessages.LLM_RESPONSE
|
||||
return safe_exception_message(exc)
|
||||
except Exception:
|
||||
return "Something went wrong when calling the LLM."
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -19,17 +19,16 @@ from app.db import (
|
|||
DocumentStatus,
|
||||
DocumentType,
|
||||
)
|
||||
from app.indexing_pipeline.cache import build_chunk_embeddings
|
||||
from app.indexing_pipeline.cache.cached_indexing import chunk_markdown, embed_batch
|
||||
from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
persist_scratch_index,
|
||||
attach_chunks_to_document,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
|
|
@ -381,50 +380,53 @@ class IndexingPipelineService:
|
|||
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
t_step = time.perf_counter()
|
||||
existing = await self._load_existing_chunks(document.id)
|
||||
if existing and self._reconcile_enabled():
|
||||
chunk_count = await self._reindex_incrementally(
|
||||
document, content, connector_doc, existing
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
else:
|
||||
from app.config import config
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
chunks = await self._reindex_from_scratch(
|
||||
document, content, connector_doc
|
||||
t_step = time.perf_counter()
|
||||
if connector_doc.should_use_code_chunker:
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=True,
|
||||
)
|
||||
chunk_count = len(chunks)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
await persist_scratch_index(
|
||||
self.session,
|
||||
document,
|
||||
content,
|
||||
chunks,
|
||||
batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE,
|
||||
perf=perf,
|
||||
else:
|
||||
# Use the table-aware hybrid chunker so Markdown tables are not
|
||||
# split mid-row (see issue #1334).
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text_hybrid,
|
||||
connector_doc.source_markdown,
|
||||
)
|
||||
|
||||
texts_to_embed = [content, *chunk_texts]
|
||||
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
|
||||
summary_embedding, *chunk_embeddings = embeddings
|
||||
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=emb)
|
||||
for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
|
||||
]
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
document.embedding = summary_embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_index,
|
||||
)
|
||||
log_index_success(ctx, chunk_count=chunk_count)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
outcome_status = "success"
|
||||
|
||||
await self._enqueue_ai_sort_if_enabled(document)
|
||||
|
|
@ -481,89 +483,6 @@ class IndexingPipelineService:
|
|||
persist_span_cm.__exit__(*sys.exc_info())
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def _reconcile_enabled() -> bool:
|
||||
from app.config import config
|
||||
|
||||
return config.CHUNK_RECONCILE_ENABLED
|
||||
|
||||
async def _load_existing_chunks(self, document_id: int) -> list[ExistingChunk]:
|
||||
result = await self.session.execute(
|
||||
select(Chunk.id, Chunk.content, Chunk.position).where(
|
||||
Chunk.document_id == document_id
|
||||
)
|
||||
)
|
||||
return [
|
||||
ExistingChunk(id=row.id, content=row.content, position=row.position)
|
||||
for row in result
|
||||
]
|
||||
|
||||
async def _reindex_from_scratch(
|
||||
self, document: Document, content: str, connector_doc: ConnectorDocument
|
||||
) -> list[Chunk]:
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
summary_embedding, chunk_pairs = await build_chunk_embeddings(
|
||||
content,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
|
||||
document.embedding = summary_embedding
|
||||
return [
|
||||
Chunk(content=text, embedding=emb, position=i)
|
||||
for i, (text, emb) in enumerate(chunk_pairs)
|
||||
]
|
||||
|
||||
async def _reindex_incrementally(
|
||||
self,
|
||||
document: Document,
|
||||
content: str,
|
||||
connector_doc: ConnectorDocument,
|
||||
existing: list[ExistingChunk],
|
||||
) -> int:
|
||||
"""Edit path: keep rows whose text survived, embed only new texts.
|
||||
|
||||
Unchanged rows keep their embedding and their HNSW/GIN index entries;
|
||||
moved rows get a position-only UPDATE, which touches neither index.
|
||||
"""
|
||||
new_texts = await chunk_markdown(
|
||||
content, use_code_chunker=connector_doc.should_use_code_chunker
|
||||
)
|
||||
plan = reconcile(existing, new_texts)
|
||||
|
||||
# One batch: the document-level summary vector plus the missing chunks.
|
||||
embeddings = await embed_batch([content, *[t for _, t in plan.to_embed]])
|
||||
summary_embedding, *new_embeddings = embeddings
|
||||
|
||||
if plan.reused:
|
||||
await self.session.execute(
|
||||
update(Chunk),
|
||||
[{"id": cid, "position": pos} for cid, pos in plan.reused],
|
||||
)
|
||||
if plan.to_delete:
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.id.in_(plan.to_delete))
|
||||
)
|
||||
self.session.add_all(
|
||||
Chunk(
|
||||
content=text,
|
||||
embedding=emb,
|
||||
position=pos,
|
||||
document_id=document.id,
|
||||
)
|
||||
for (pos, text), emb in zip(plan.to_embed, new_embeddings, strict=True)
|
||||
)
|
||||
document.embedding = summary_embedding
|
||||
|
||||
ot_metrics.record_chunk_reconcile(
|
||||
reused=len(existing) - len(plan.to_delete),
|
||||
embedded=len(plan.to_embed),
|
||||
deleted=len(plan.to_delete),
|
||||
)
|
||||
return len(new_texts)
|
||||
|
||||
async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None:
|
||||
"""Fire-and-forget: enqueue incremental AI sort if the search space has it enabled."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ async def list_notifications(
|
|||
query = query.where(unread_filter)
|
||||
count_query = count_query.where(unread_filter)
|
||||
elif filter == "errors":
|
||||
error_filter = (Notification.type == "insufficient_credits") | (
|
||||
error_filter = (Notification.type == "page_limit_exceeded") | (
|
||||
Notification.notification_metadata["status"].astext == "failed"
|
||||
)
|
||||
query = query.where(error_filter)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue